Layer integration (#83)

* integrated parallel layers for ease of building models

* integrated 2.5d layers

* cleaned codes and unit tests

* added log metric by step hook; updated imagenet benchmark; fixed some bugs

* reworked initialization; cleaned codes

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
アマデウス 2021-12-27 15:04:32 +08:00 committed by GitHub
parent 5c3843dc98
commit 0fedef4f3c
118 changed files with 4941 additions and 8116 deletions

benchmark/README.md Normal file
View File

@ -0,0 +1,66 @@
# Benchmark for Tuning Accuracy and Efficiency
## Overview
The benchmark covers our efforts in using Colossal-AI to train different tasks to achieve SOTA results.
We are interested in both validation accuracy and training speed, and prefer larger batch sizes to take advantage of more GPU devices.
For example, we train Vision Transformer with batch size 512 on CIFAR10 and 4096 on ImageNet1k, settings that are rarely used in existing works.
Some results from the benchmark, trained with 8x A100 GPUs, are shown below.
| Task | Model | Training Time | Top-1 Accuracy |
| ---------- | ------------ | ------------- | -------------- |
| CIFAR10 | [ViT-Lite-7/4](https://arxiv.org/pdf/2104.05704.pdf) | ~ 16 min | ~ 90.5% |
| ImageNet1k | ViT-S/16 | ~ 16.5 h | ~ 74.5% |
The `train.py` script in each task directory runs training with a specific configuration script from `configs/` for each parallelism.
Supported parallelisms include data parallel only (config ends with `vanilla`), 1D (ends with `1d`), 2D (ends with `2d`), 2.5D (ends with `2p5d`), and 3D (ends with `3d`).
Each configuration script includes the following elements, taking the ImageNet1k task as an example:
```
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
# data parallel only
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None
# parallelism setting
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, ) # amp setting
gradient_accumulation = 2 # accumulate 2 steps for gradient update
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation # actual batch size for dataloader
clip_grad_norm = 1.0 # clip gradient with norm 1.0
```
Uppercase elements are the hyperparameters read by `train.py`, while lowercase elements are the settings Colossal-AI uses to initialize the training.
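For reference, with `TOTAL_BATCH_SIZE = 4096` and `gradient_accumulation = 2`, the dataloader uses `BATCH_SIZE = 2048` per step, and `train.py` further divides it by the data-parallel world size to get the per-GPU batch. Below is a minimal sketch (not part of the benchmark; the config path and placeholder model are hypothetical) of how the uppercase entries are consumed through `gpc.config` after launch:
```
import colossalai
import torch
from colossalai.core import global_context as gpc

# torchrun provides RANK/WORLD_SIZE/MASTER_ADDR/...; the config path is hypothetical
colossalai.launch_from_torch(config='./configs/vit_1d.py')

model = torch.nn.Linear(8, 8)    # placeholder model for illustration

# uppercase entries become attributes of gpc.config
per_gpu_batch_size = gpc.config.BATCH_SIZE // gpc.data_parallel_size
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=gpc.config.LEARNING_RATE,
                              weight_decay=gpc.config.WEIGHT_DECAY)

# lowercase entries (parallel, fp16, gradient_accumulation, clip_grad_norm)
# are consumed by colossalai.initialize() when it builds the engine
```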
## Usage
To start training, use the following command to run each worker:
```
$ DATA=/path/to/dataset python train.py --world_size=WORLD_SIZE \
--rank=RANK \
--local_rank=LOCAL_RANK \
--host=MASTER_IP_ADDRESS \
--port=MASTER_PORT \
--config=CONFIG_FILE
```
Alternatively, you can start training with `torchrun`:
```
$ DATA=/path/to/dataset torchrun --nproc_per_node=NUM_GPUS_PER_NODE \
--nnodes=NUM_NODES \
--node_rank=NODE_RANK \
--master_addr=MASTER_IP_ADDRESS \
--master_port=MASTER_PORT \
train.py --config=CONFIG_FILE
```

View File

@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
seed = 42
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

View File

@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
seed = 42
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

View File

@ -0,0 +1,19 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)
seed = 42
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

View File

@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
seed = 42
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

View File

@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None
NUM_EPOCHS = 200
WARMUP_EPOCHS = 40
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
seed = 42
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"

benchmark/cifar/train.py Normal file
View File

@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import colossalai
import torch
import torchvision
from colossalai.builder import *
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
LogMetricByEpochHook,
LogMetricByStepHook,
LogTimingByEpochHook, LossHook,
LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms
DATASET_PATH = str(os.environ['DATA'])
def build_cifar(batch_size):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
train=True,
download=True,
transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=batch_size,
num_workers=4,
pin_memory=True)
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
return train_dataloader, test_dataloader
def train_cifar():
args = colossalai.get_default_parser().parse_args()
# standard launch
# colossalai.launch(config=args.config,
# rank=args.rank,
# world_size=args.world_size,
# local_rank=args.local_rank,
# host=args.host,
# port=args.port)
# launch from torchrun
colossalai.launch_from_torch(config=args.config)
logger = get_dist_logger()
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
tp = gpc.config.parallel.tensor.mode
model = vit_lite_depth7_patch4_32(tensor_parallel=tp)
train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
steps_per_epoch = len(train_dataloader)
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
lr_scheduler=lr_scheduler)
logger.info("Engine is built", ranks=[0])
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0])
hooks = [
LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
]
logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS,
hooks=hooks,
display_progress=True,
test_interval=1)
if __name__ == '__main__':
train_cifar()

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,27 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,211 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import glob
import os
import colossalai
import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec
import torch
from colossalai.builder import *
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
DATASET_PATH = str(os.environ['DATA'])
TRAIN_RECS = DATASET_PATH + '/train/*'
VAL_RECS = DATASET_PATH + '/validation/*'
TRAIN_IDX = DATASET_PATH + '/idx_files/train/*'
VAL_IDX = DATASET_PATH + '/idx_files/validation/*'
class DaliDataloader(DALIClassificationIterator):
def __init__(self,
tfrec_filenames,
tfrec_idx_filenames,
shard_id=0,
num_shards=1,
batch_size=128,
num_threads=4,
resize=256,
crop=224,
prefetch=2,
training=True,
gpu_aug=False,
cuda=True):
pipe = Pipeline(batch_size=batch_size,
num_threads=num_threads,
device_id=torch.cuda.current_device() if cuda else None,
seed=1024)
with pipe:
inputs = fn.readers.tfrecord(path=tfrec_filenames,
index_path=tfrec_idx_filenames,
random_shuffle=training,
shard_id=shard_id,
num_shards=num_shards,
initial_fill=10000,
read_ahead=True,
prefetch_queue_depth=prefetch,
name='Reader',
features={
'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
})
images = inputs["image/encoded"]
if training:
images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu')
flip_lr = fn.random.coin_flip(probability=0.5)
else:
# decode jpeg and resize
images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
images = fn.resize(images,
device='gpu' if gpu_aug else 'cpu',
resize_x=resize,
resize_y=resize,
dtype=types.FLOAT,
interp_type=types.INTERP_TRIANGULAR)
flip_lr = False
# center crop and normalise
images = fn.crop_mirror_normalize(images,
dtype=types.FLOAT,
crop=(crop, crop),
mean=[127.5],
std=[127.5],
mirror=flip_lr)
label = inputs["image/class/label"] - 1 # 0-999
# LSG: element_extract will raise exception, let's flatten outside
# label = fn.element_extract(label, element_map=0) # Flatten
if cuda: # transfer data to gpu
pipe.set_outputs(images.gpu(), label.gpu())
else:
pipe.set_outputs(images, label)
pipe.build()
last_batch_policy = 'DROP' if training else 'PARTIAL'
super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)
def __iter__(self):
# if not reset (after an epoch), reset; if just initialize, ignore
if self._counter >= self._size or self._size < 0:
self.reset()
return self
def __next__(self):
data = super().__next__()
img, label = data[0]['data'], data[0]['label']
label = label.squeeze()
return (img, ), (label, )
def build_dali_train(batch_size):
return DaliDataloader(
sorted(glob.glob(TRAIN_RECS)),
sorted(glob.glob(TRAIN_IDX)),
batch_size=batch_size,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=True,
gpu_aug=True,
cuda=True,
)
def build_dali_test(batch_size):
return DaliDataloader(
sorted(glob.glob(VAL_RECS)),
sorted(glob.glob(VAL_IDX)),
batch_size=batch_size,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=False,
gpu_aug=True,
cuda=True,
)
def train_imagenet():
args = colossalai.get_default_parser().parse_args()
# standard launch
# colossalai.launch(config=args.config,
# rank=args.rank,
# world_size=args.world_size,
# local_rank=args.local_rank,
# host=args.host,
# port=args.port)
# launch from torchrun
colossalai.launch_from_torch(config=args.config)
logger = get_dist_logger()
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
tp = gpc.config.parallel.tensor.mode
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0])
hooks = [
LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
]
logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS,
hooks=hooks,
display_progress=True,
test_interval=1)
if __name__ == '__main__':
train_imagenet()

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,27 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None
NUM_EPOCHS = 300
WARMUP_EPOCHS = 32
parallel = dict(
pipeline=1,
tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)
fp16 = dict(mode=AMP_TYPE.TORCH, )
gradient_accumulation = 2
BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation
clip_grad_norm = 1.0
LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"

View File

@ -0,0 +1,211 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import glob
import os
import colossalai
import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec
import torch
from colossalai.builder import *
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
DATASET_PATH = str(os.environ['DATA'])
TRAIN_RECS = DATASET_PATH + '/train/*'
VAL_RECS = DATASET_PATH + '/validation/*'
TRAIN_IDX = DATASET_PATH + '/idx_files/train/*'
VAL_IDX = DATASET_PATH + '/idx_files/validation/*'
class DaliDataloader(DALIClassificationIterator):
def __init__(self,
tfrec_filenames,
tfrec_idx_filenames,
shard_id=0,
num_shards=1,
batch_size=128,
num_threads=4,
resize=256,
crop=224,
prefetch=2,
training=True,
gpu_aug=False,
cuda=True):
pipe = Pipeline(batch_size=batch_size,
num_threads=num_threads,
device_id=torch.cuda.current_device() if cuda else None,
seed=1024)
with pipe:
inputs = fn.readers.tfrecord(path=tfrec_filenames,
index_path=tfrec_idx_filenames,
random_shuffle=training,
shard_id=shard_id,
num_shards=num_shards,
initial_fill=10000,
read_ahead=True,
prefetch_queue_depth=prefetch,
name='Reader',
features={
'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
})
images = inputs["image/encoded"]
if training:
images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu')
flip_lr = fn.random.coin_flip(probability=0.5)
else:
# decode jpeg and resize
images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
images = fn.resize(images,
device='gpu' if gpu_aug else 'cpu',
resize_x=resize,
resize_y=resize,
dtype=types.FLOAT,
interp_type=types.INTERP_TRIANGULAR)
flip_lr = False
# center crop and normalise
images = fn.crop_mirror_normalize(images,
dtype=types.FLOAT,
crop=(crop, crop),
mean=[127.5],
std=[127.5],
mirror=flip_lr)
label = inputs["image/class/label"] - 1 # 0-999
# LSG: element_extract will raise exception, let's flatten outside
# label = fn.element_extract(label, element_map=0) # Flatten
if cuda: # transfer data to gpu
pipe.set_outputs(images.gpu(), label.gpu())
else:
pipe.set_outputs(images, label)
pipe.build()
last_batch_policy = 'DROP' if training else 'PARTIAL'
super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)
def __iter__(self):
# if not reset (after an epoch), reset; if just initialize, ignore
if self._counter >= self._size or self._size < 0:
self.reset()
return self
def __next__(self):
data = super().__next__()
img, label = data[0]['data'], data[0]['label']
label = label.squeeze()
return (img, ), (label, )
def build_dali_train(batch_size):
return DaliDataloader(
sorted(glob.glob(TRAIN_RECS)),
sorted(glob.glob(TRAIN_IDX)),
batch_size=batch_size,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=True,
gpu_aug=True,
cuda=True,
)
def build_dali_test(batch_size):
return DaliDataloader(
sorted(glob.glob(VAL_RECS)),
sorted(glob.glob(VAL_IDX)),
batch_size=batch_size,
shard_id=gpc.get_local_rank(ParallelMode.DATA),
num_shards=gpc.get_world_size(ParallelMode.DATA),
training=False,
gpu_aug=True,
cuda=True,
)
def train_imagenet():
args = colossalai.get_default_parser().parse_args()
# standard launch
# colossalai.launch(config=args.config,
# rank=args.rank,
# world_size=args.world_size,
# local_rank=args.local_rank,
# host=args.host,
# port=args.port)
# launch from torchrun
colossalai.launch_from_torch(config=args.config)
logger = get_dist_logger()
if hasattr(gpc.config, 'LOG_PATH'):
if gpc.get_global_rank() == 0:
log_path = gpc.config.LOG_PATH
if not os.path.exists(log_path):
os.mkdir(log_path)
logger.log_to_file(log_path)
tp = gpc.config.parallel.tensor.mode
model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')
train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)
optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
total_steps=gpc.config.NUM_EPOCHS,
warmup_steps=gpc.config.WARMUP_EPOCHS)
engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader)
logger.info("Engine is built", ranks=[0])
timer = MultiTimer()
trainer = Trainer(engine=engine, logger=logger, timer=timer)
logger.info("Trainer is built", ranks=[0])
hooks = [
LogMetricByEpochHook(logger=logger),
LogMetricByStepHook(),
# LogTimingByEpochHook(timer=timer, logger=logger),
# LogMemoryByEpochHook(logger=logger),
AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
LossHook(),
ThroughputHook(),
LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
]
logger.info("Train start", ranks=[0])
trainer.fit(train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.NUM_EPOCHS,
hooks=hooks,
display_progress=True,
test_interval=1)
if __name__ == '__main__':
train_imagenet()

View File

@ -1,14 +1,17 @@
from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
from .p2p import (send_forward, send_forward_recv_forward,
send_backward_recv_forward, send_backward,
send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward,
recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [
'all_gather', 'reduce_scatter', 'all_reduce',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'all_gather', 'reduce_scatter', 'all_reduce', 'broadcast', 'reduce',
'send_forward', 'send_forward_recv_forward',
'send_forward_backward_recv_forward_backward', 'send_backward',
'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward',
'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
]

View File

@ -3,6 +3,7 @@
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch import Tensor
from colossalai.context import ParallelMode
@ -10,8 +11,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode, async_op=False) -> Tensor:
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
@ -25,29 +25,31 @@ def all_gather(tensor: Tensor, dim: int,
:rtype: :class:`torch.Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone()
# shape = list(temp.shape)
# shape[dim] *= depth
# out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
# out = list(torch.chunk(out, depth, dim=dim))
# out = [val.contiguous() for val in out]
shape = [1] * len(tensor.shape)
shape[dim] = depth
out = tensor.repeat(shape)
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
op = dist.all_gather(tensor_list=out,
tensor=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
# out = torch.cat(out, dim=dim)
if depth == 1:
out = [tensor]
work = None
else:
shape = list(tensor.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
temp = list(torch.chunk(out, depth, dim=0))
work = dist.all_gather(tensor_list=temp,
tensor=tensor.transpose(0, dim).contiguous(),
group=gpc.get_group(parallel_mode),
async_op=async_op)
out = torch.transpose(out, 0, dim)
if async_op:
return out, op
return out, work
else:
return out
def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode, async_op=False) -> Tensor:
def reduce_scatter(tensor: Tensor,
dim: int,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
@ -61,52 +63,57 @@ def reduce_scatter(tensor: Tensor, dim: int,
:rtype: :class:`Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
# temp = list(torch.chunk(tensor, depth, dim=dim))
# temp = [val.contiguous() for val in temp]
# out = torch.zeros(temp[0].shape,
# dtype=temp[0].dtype,
# device=get_current_device())
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = temp[0].clone()
op = dist.reduce_scatter(output=out,
input_list=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if depth == 1:
out = tensor
work = None
else:
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
work = dist.reduce_scatter(output=out,
input_list=temp,
op=op,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return out, op
return out, work
else:
return out
def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode,
async_op=False) -> Tensor:
op = dist.all_reduce(tensor,
group=gpc.get_group(parallel_mode),
async_op=async_op)
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
work = None
else:
work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, op
return tensor, work
else:
return tensor
# def scatter(tensor: Tensor, src: int, dim: int,
# parallel_mode: ParallelMode) -> Tensor:
# """Scatters in a specific dimension from source rank to all ranks in
# the parallel group.
# :param tensor: Tensor to be scattered
# :param dim: The dimension scattering in
# :param parallel_mode: Parallel group mode used in this communication
# :type tensor: Tensor
# :type dim: int
# :type parallel_mode: ParallelMode
# :return: The tensor generated by scatter
# :rtype: Tensor
# """
# depth = gpc.get_world_size(parallel_mode)
# temp = tensor.clone()
# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
# rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
# return out
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
work = None
else:
work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
else:
return tensor
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
work = None
else:
work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
else:
return tensor
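The reworked collectives above skip communication when the parallel group has a single rank (returning the input with a `None` handle) and return the work handle when `async_op=True`. A hedged usage sketch, assuming a launched Colossal-AI context and that these functions are re-exported from `colossalai.communication` as in the `__init__.py` change above; shapes and the group mode are illustrative:
```
import torch
from colossalai.communication import all_gather, all_reduce, broadcast, reduce
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device

x = torch.randn(4, 8, device=get_current_device())

# synchronous call: returns the gathered tensor directly
gathered = all_gather(x, dim=-1, parallel_mode=ParallelMode.TENSOR)

# asynchronous call: returns (tensor, work handle); the handle is None
# when the parallel group only has one rank
summed, work = all_reduce(x, ParallelMode.TENSOR, async_op=True)
if work is not None:
    work.wait()

# primitives added in this commit; src/dst are global ranks
x = broadcast(x, src=0, parallel_mode=ParallelMode.TENSOR)
total = reduce(x, dst=0, parallel_mode=ParallelMode.TENSOR)
```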

View File

@ -497,8 +497,7 @@ class ParallelContext:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.",
ranks=[0])
f"the default parallel seed is {ParallelMode.DATA}.")
else:
if self._verbose:
self._logger.info(

View File

@ -184,8 +184,6 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
def launch_from_torch(config: Union[str, Path, Config, Dict],
host: str,
port: int,
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
@ -206,6 +204,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
launch(config=config,
local_rank=local_rank,
rank=rank,

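With this change `launch_from_torch` no longer takes `host` and `port`; the rank, local rank, world size, master address and port are all read from the environment variables set by `torchrun`. A minimal sketch of the resulting launch pattern, mirroring the benchmark scripts above:
```
# started e.g. with: torchrun --nproc_per_node=8 train.py --config=CONFIG_FILE
import colossalai

args = colossalai.get_default_parser().parse_args()
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT come from torchrun
colossalai.launch_from_torch(config=args.config)
```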
View File

@ -1,5 +1,6 @@
from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *

View File

@ -1,33 +1,140 @@
import math
import warnings
from torch import Tensor
from torch.nn import init as init
import torch.nn as nn
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
if init_method == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(tensor, -a, a)
elif init_method == 'jax_embed':
def zeros_():
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.zeros_(tensor)
return initializer
def ones_():
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.ones_(tensor)
return initializer
def uniform_(a: float = 0., b: float = 1.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.uniform_(tensor, a, b)
return initializer
def normal_(mean: float = 0., std: float = 1.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.normal_(tensor, mean, std)
return initializer
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
return nn.init.trunc_normal_(tensor, mean, std, a, b)
return initializer
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
if mode == 'fan_in':
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
elif mode == 'fan_out':
assert fan_out is not None, 'Fan_out is not provided.'
fan = fan_out
else:
raise ValueError(f'Invalid initialization mode \'{mode}\'')
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
bound = math.sqrt(3.) * std
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
return tensor
if mode == 'fan_in':
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
elif mode == 'fan_out':
assert fan_out is not None, 'Fan_out is not provided.'
fan = fan_out
else:
raise ValueError(f'Invalid initialization mode \'{mode}\'')
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
return nn.init.normal_(tensor, 0, std)
return initializer
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
if fan_out is not None:
fan += fan_out
std = gain * math.sqrt(scale / float(fan))
bound = a * std
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def xavier_normal_(scale: float = 2., gain: float = 1.):
# adapted from torch.nn.init
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
fan = fan_in
if fan_out is not None:
fan += fan_out
std = gain * math.sqrt(scale / float(fan))
return nn.init.normal_(tensor, 0., std)
return initializer
def lecun_uniform_():
# adapted from jax.nn.initializers
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
var = 1.0 / fan_in
bound = math.sqrt(3 * var)
return nn.init.uniform_(tensor, -bound, bound)
return initializer
def lecun_normal_():
# adapted from jax.nn.initializers
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
assert fan_in is not None, 'Fan_in is not provided.'
std = math.sqrt(1.0 / fan_in)
init.trunc_normal_(tensor, std=std / .87962566103423978)
elif init_method == 'zero':
init.zeros_(tensor)
return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
if init_method == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(tensor, -bound, bound)
elif init_method == 'jax':
init.normal_(tensor, std=1e-6)
elif init_method == 'jax_embed':
init.trunc_normal_(tensor, std=.02)
elif init_method == 'zero':
init.zeros_(tensor)
return initializer
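The initializers are now factories: each call returns a closure that is later applied to a tensor with explicit `fan_in`/`fan_out`, so tensor-parallel layers can pass the full, unsharded dimensions of a sharded parameter. A hedged sketch (the `colossalai.nn.init` import path is assumed from the `from .. import init` usage in the layer code below; shapes are illustrative):
```
import math
import torch
from colossalai.nn import init   # assumed import path

weight = torch.empty(1024, 256)
bias = torch.empty(1024)

# the same defaults used by the unified Linear wrapper below
weight_initializer = init.kaiming_uniform_(a=math.sqrt(5))
bias_initializer = init.xavier_uniform_(a=1, scale=1)

# fan_in/fan_out are supplied by the caller rather than inferred from the
# tensor, so a sharded parameter can still be initialized with its global fan
weight_initializer(weight, fan_in=256, fan_out=1024)
bias_initializer(bias, fan_in=256)
```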

View File

@ -1,8 +1,3 @@
from .colossalai_layer import *
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .non_parallel_layers import *
from .wrapper import *

View File

@ -1,11 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import collections.abc
from itertools import repeat
import numpy as np
from colossalai.utils.common import print_rank_0
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.utils import checkpoint
@ -19,8 +18,7 @@ class CheckpointModule(nn.Module):
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError(
'CheckpointModule should implement _forward method instead of origin forward')
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
def forward(self, *args, **kwargs):
if self._use_checkpoint:
@ -36,6 +34,7 @@ class CheckpointModule(nn.Module):
self._use_checkpoint = False
return super().eval()
def divide(numerator, denominator):
""" only allow exact division """
assert numerator % denominator == 0, \
@ -59,7 +58,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, num_partitions)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):

View File

@ -1,138 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_
def _split(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, \
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = gpc.get_local_rank(parallel_mode)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# all gather
rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _ReduceGrad(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.mode = parallel_mode
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.mode), None
class _ReduceInput(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_, parallel_mode):
return _reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _gather(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.mode, ctx.dim), None, None
def reduce_grad(input_, parallel_mode):
return _ReduceGrad.apply(input_, parallel_mode)
def reduce_input(input_, parallel_mode):
return _ReduceInput.apply(input_, parallel_mode)
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)

View File

@ -0,0 +1,231 @@
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from torch.nn.modules.activation import *
from torch.nn.modules.adaptive import *
from torch.nn.modules.batchnorm import *
from torch.nn.modules.channelshuffle import *
from torch.nn.modules.conv import *
from torch.nn.modules.distance import *
from torch.nn.modules.dropout import *
from torch.nn.modules.flatten import *
from torch.nn.modules.fold import *
from torch.nn.modules.instancenorm import *
from torch.nn.modules.linear import *
from torch.nn.modules.normalization import *
from torch.nn.modules.padding import *
from torch.nn.modules.pixelshuffle import *
from torch.nn.modules.pooling import *
from torch.nn.modules.rnn import *
from torch.nn.modules.sparse import *
from torch.nn.modules.transformer import *
from torch.nn.modules.upsampling import *
from .. import init as init
from .vanilla import *
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': VanillaClassifier,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_embedding = {'3d': Embedding3D}
_parallel_patchembedding = {
None: VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None,
**kwargs) -> None:
super().__init__()
if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
tensor_parallel: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)
class Classifier(nn.Module):
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
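These wrappers let the same model code run under any tensor-parallel mode by passing the mode string from the config. A hedged usage sketch, assuming a launched context and that the wrappers are re-exported through `colossalai.nn`; the module and sizes below are illustrative:
```
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.nn import Classifier, LayerNorm, Linear


class SimpleHead(nn.Module):
    """Illustrative head built from the unified layer wrappers."""
    def __init__(self, dim, num_classes, tensor_parallel=None):
        super().__init__()
        self.norm = LayerNorm(dim, tensor_parallel=tensor_parallel)
        # note: for 1D parallelism, Linear expects '1d_col' or '1d_row'
        # (see the _parallel_linear mapping above)
        self.dense = Linear(dim, dim, tensor_parallel=tensor_parallel)
        self.classifier = Classifier(dim, num_classes, tensor_parallel=tensor_parallel)

    def forward(self, x):
        return self.classifier(self.dense(self.norm(x)))


# the mode comes straight from the parallel config, e.g. None, '2d' or '2.5d'
tp = gpc.config.parallel.tensor.mode
head = SimpleHead(dim=384, num_classes=10, tensor_parallel=tp)
```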

View File

@ -1,8 +0,0 @@
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
__all__ = [
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
]

View File

@ -1,301 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
from .._common_utils import to_2tuple
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normlization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module
class VanillaViTPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: size of a patch
:type patch_size: int
:param in_chans: input channels
:type in_chans: int
:param embed_dim: embedding dimension
:type embed_dim: int
:param norm_layer: layer norm class, defaults to None
:type norm_layer: Callable
:param flattern: whether flatten the output
:type flatten: bool
:param drop: dropout rate
:type drop: float
"""
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None, flatten=True, drop=0.):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x
@LAYERS.register_module
class VanillaViTMLP(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
:param in_features: input channels
:type in_features: int
:param hidden_features: channels of the output of the first dense layer
:type hidden_features: int
:param hidden_features: channels of the output of the second dense layer
:type hidden_features: int
:param act_layer: activation function
:type act_layer: Callable
:param drop: dropout rate
:type drop: float
"""
def __init__(self, in_features, hidden_features, out_features, act_layer=nn.GELU, drop=0.):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
:param drop_prob: probability for dropout
:type drop_prob: float
:param training: whether it is training mode
:type training: bool
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
@LAYERS.register_module
class VanillaViTDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
:param drop_prob: probability for dropout
:type drop_path: float
"""
def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@LAYERS.register_module
class VanillaViTAttention(nn.Module):
"""Vanilla attention layer of Vision Transformer
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads
:type num_heads: int, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param proj_drop: dropout probability for linear layer, defaults to 0.
:type proj_drop: float, optional
"""
def __init__(self, dim, num_heads, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@LAYERS.register_module
class VanillaViTBlock(nn.Module):
"""Vanilla Vision Transformer block
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads
:type num_heads: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
:type mlp_ratio: float, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param drop: dropout probability, defaults to 0.
:type drop: float, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param drop_path: drop path probability, defaults to 0.
:type drop_path: float, optional
:param act_layer: activation function, defaults to nn.GELU
:type act_layer: torch.nn.Module, optional
:param norm_layer: normalization layer, defaults to nn.LayerNorm
:type norm_layer: torch.nn.Module, optional
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = LAYERS.get_module('VanillaViTAttention')(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = LAYERS.get_module('VanillaViTDropPath')(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module
class VanillaViTHead(nn.Module):
"""Output layer of vanilla Vision Transformer
:param in_features: size of input tensor
:type in_features: int
:param intermediate_features: hidden size
:type intermediate_features: int
:param out_features: size of output tensor
:type out_features: int
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
in_features,
intermediate_features,
out_features,
bias=True
):
super().__init__()
self.linear_1 = nn.Linear(
in_features, intermediate_features, bias=bias)
self.act = nn.Tanh()
self.linear_2 = nn.Linear(
intermediate_features, out_features, bias=bias)
def forward(self, x):
x = x[:, 0, :].squeeze(1)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x

View File

@ -1,11 +1,4 @@
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
__all__ = [
'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
]
__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']

View File

@ -1,220 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from .._common_utils import divide, ACT2FN
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from .layers import Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
@LAYERS.register_module
class TransformerMLP1D(ParallelLayer):
"""MLP.
MLP takes the input with h hidden states, projects it to mlp_ratio * h
hidden dimension, performs a nonlinear transformation, and projects the
state back into h hidden dimension.
"""
def __init__(self,
in_features: int,
mlp_ratio: float = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super(TransformerMLP1D, self).__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
bias=not skip_bias_add,
dtype=dtype,
gather_output = False,
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
bias=not skip_bias_add,
dtype=dtype,
parallel_input = True,
)
self.dropout = nn.Dropout(dropout_prob)
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
def forward(self, x):
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
output = self.layernorm(x + output)
return output
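The column-parallel first projection (`gather_output=False`) feeds a row-parallel second projection (`parallel_input=True`), so only the final output needs an all-reduce. The standalone sketch below (plain PyTorch, manual chunking in place of process groups, not part of this commit) illustrates why the per-rank partial outputs sum to the serial result:

```
import torch
import torch.nn.functional as F

hidden, ratio, world_size = 64, 4, 2
x = torch.randn(8, hidden, dtype=torch.float64)
w1 = torch.randn(hidden, hidden * ratio, dtype=torch.float64)   # column-split across ranks
w2 = torch.randn(hidden * ratio, hidden, dtype=torch.float64)   # row-split across ranks

serial = F.gelu(x @ w1) @ w2

partial = torch.zeros(8, hidden, dtype=torch.float64)
for w1_shard, w2_shard in zip(w1.chunk(world_size, dim=1), w2.chunk(world_size, dim=0)):
    partial += F.gelu(x @ w1_shard) @ w2_shard   # each rank's local contribution

print(torch.allclose(serial, partial))           # True: the all-reduce restores the sum
```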
@LAYERS.register_module
class TransformerSelfAttention1D(ParallelLayer):
"""Self attention layer for 1D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
# need to re-enable torch grad to enable fused optimization.
# self.layernorm = LayerNorm1D(
# hidden_size,
# dtype=dtype)
self.layernorm = nn.LayerNorm(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer1D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layer, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output

View File

@ -1,6 +1,11 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from .._common_utils import divide
@ -15,4 +20,128 @@ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_
def _split(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# Split along last dimension.
dim_size = input_.size(dim)
assert dim_size % world_size == 0, \
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
f'cannot split tensor evenly'
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
rank = gpc.get_local_rank(parallel_mode)
output = tensor_list[rank].contiguous()
return output
def _gather(input_, parallel_mode, dim=-1):
# skip if only one rank involved
world_size = gpc.get_world_size(parallel_mode)
if world_size == 1:
return input_
# all gather
rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()
return output
class _ReduceGrad(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.mode = parallel_mode
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, ctx.mode), None
class _ReduceInput(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_, parallel_mode):
return _reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _split(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.mode, ctx.dim), None, None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_, parallel_mode, dim):
ctx.mode = parallel_mode
ctx.dim = dim
return _gather(input_, parallel_mode, dim)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.mode, ctx.dim), None, None
def reduce_grad(input_, parallel_mode):
return _ReduceGrad.apply(input_, parallel_mode)
def reduce_input(input_, parallel_mode):
return _ReduceInput.apply(input_, parallel_mode)
def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
def gather_forward_split_backward(input_, parallel_mode, dim):
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
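These wrappers come in adjoint pairs: an op that gathers along a dimension in the forward pass must hand each rank its own slice of the gradient in the backward pass, and vice versa. A single-process sketch of that relationship (using `torch.cat` in place of `dist.all_gather`, not part of this commit):

```
import torch

world_size, dim = 2, -1
# Each "rank" holds one shard; a real run would use dist.all_gather instead of torch.cat.
shards = [torch.randn(4, 8, dtype=torch.float64, requires_grad=True) for _ in range(world_size)]

gathered = torch.cat(shards, dim=dim)        # forward: gather along the last dimension
upstream = torch.randn_like(gathered)
gathered.backward(upstream)

# backward: each shard receives exactly its slice of the upstream gradient
expected = upstream.chunk(world_size, dim=dim)
print(all(torch.equal(s.grad, e) for s, e in zip(shards, expected)))   # True
```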

View File

@ -1,411 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from colossalai import context
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from .layers import Linear1D_Col, Linear1D_Row
from ..base_layer import ParallelLayer
from .._common_utils import to_2tuple
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP1D(ParallelLayer):
"""MLP layer for 1D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
weight_init='torch'
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
init_weight=weight_init,
init_bias=weight_init
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
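The `checkpoint=True` path trades compute for memory by recomputing the block during the backward pass. A minimal standalone illustration using `torch.utils.checkpoint` (not `colossalai.utils.checkpoint`, and not part of this commit):

```
import torch
from torch.utils.checkpoint import checkpoint

mlp = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64))
x = torch.randn(4, 64, requires_grad=True)
out = checkpoint(mlp, x)     # activations are recomputed during backward instead of stored
out.sum().backward()
print(x.grad.shape)          # torch.Size([4, 64])
```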
@LAYERS.register_module
class ViTSelfAttention1D(ParallelLayer):
"""Self-attention layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_bias = 'zero'
else:
init_bias = weight_init
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
init_weight=weight_init, init_bias=init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead1D(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
init_weight = 'zero'
init_bias = 'zero'
else:
init_weight = weight_init
init_bias = weight_init
self.linear = Linear1D_Col(
hidden_size,
num_classes,
dtype=dtype,
gather_output=True,
init_weight=init_weight,
init_bias=init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTHead(ParallelLayer):
"""Output layer for 1D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
):
super().__init__()
self.linear = nn.Linear(
hidden_size,
num_classes,
dtype=dtype
)
self._broadcast_linear_params()
def _broadcast_linear_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.linear.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.linear.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding1D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size
)
if weight_init == 'jax':
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
# sync
self._broadcast_conv_params()
def _broadcast_conv_params(self) -> None:
self.to(get_current_device())
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
dist.broadcast(self.proj.weight, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
dist.broadcast(self.proj.bias, src=ranks[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
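A standalone sketch of the patch-embedding arithmetic above (hypothetical ViT-S/16 sizes, not part of this commit): a stride-`patch_size` convolution turns a `(B, C, H, W)` image into `(H/p) * (W/p)` tokens of width `embed_dim`.

```
import torch
import torch.nn as nn

img_size, patch_size, embed_dim = 224, 16, 384   # hypothetical ViT-S/16 sizes
proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
x = proj(torch.randn(2, 3, img_size, img_size))  # (2, 384, 14, 14)
tokens = x.flatten(2).transpose(1, 2)            # (2, 196, 384), 196 = (224 // 16) ** 2
print(tokens.shape)
```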
@LAYERS.register_module
class ViTTokenFuser1D(ParallelLayer):
"""
Fuses the CLS token and position embedding into the input sequence
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.empty(
1, self.num_patches + 1, self.embed_dim))
nn.init.trunc_normal_(self.pos_embed, std=.02)
# move to cuda before broadcast
self.to(get_current_device())
dist.broadcast(self.pos_embed,
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
group=gpc.get_group(ParallelMode.TENSOR))
self.pos_drop = nn.Dropout(p=drop_rate)
def forward(self, x: Tensor) -> Tensor:
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x.contiguous()
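A standalone sketch of the fusing step (sizes are illustrative only, not taken from this commit): prepend one CLS token per sample and add a positional embedding of length `num_patches + 1`.

```
import torch

B, num_patches, embed_dim = 2, 196, 384
patches = torch.randn(B, num_patches, embed_dim)
cls_token = torch.zeros(1, 1, embed_dim).expand(B, -1, -1)   # learnable in the real layer
pos_embed = torch.zeros(1, num_patches + 1, embed_dim)       # learnable in the real layer
x = torch.cat((cls_token, patches), dim=1) + pos_embed
print(x.shape)   # torch.Size([2, 197, 384])
```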

View File

@ -3,25 +3,24 @@
import math
import numbers
from typing import Callable, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.nn.parameter import Parameter
from typing import Tuple
import importlib
from colossalai.context import seed, ParallelMode
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import FusedLayerNormAffineFunction1D
from torch import Tensor
from torch.nn.parameter import Parameter
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
split_forward_gather_backward
from ..base_layer import ParallelLayer
from ._operation import FusedLayerNormAffineFunction1D
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
@LAYERS.register_module
@ -44,79 +43,46 @@ class Linear1D_Col(ParallelLayer):
which is :math:`Y_i = XA_i`, defaults to False
:type gather_output: bool, optional
"""
def __init__(self,
in_features: int,
output_size: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
self.in_features = in_features
self.out_features = output_size
self.out_features = out_features
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.in_features,
**factory_kwargs))
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
**factory_kwargs))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
@ -133,8 +99,7 @@ class Linear1D_Col(ParallelLayer):
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(
output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
if self.skip_bias_add:
@ -158,17 +123,15 @@ class Linear1D_Row(ParallelLayer):
:param parallel_input: If set to ``True``, the input is assumed to be already split across the parallel group, defaults to False
:type parallel_input: bool, optional
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
parallel_input: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
# Keep input parameters
@ -186,58 +149,22 @@ class Linear1D_Row(ParallelLayer):
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.out_features,
self.input_size_per_partition,
**factory_kwargs))
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(
self.out_features,
**factory_kwargs
))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
dist.broadcast(self.bias,
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
group=gpc.get_group(ParallelMode.PARALLEL_1D))
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
@ -248,8 +175,7 @@ class Linear1D_Row(ParallelLayer):
if self.parallel_input:
input_ = input_
else:
input_ = split_forward_gather_backward(
input_, ParallelMode.PARALLEL_1D, dim=-1)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
@ -263,12 +189,13 @@ class Linear1D_Row(ParallelLayer):
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
""" Experimental
"""
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = (normalized_shape, )
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
@ -280,5 +207,4 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(
input, self.weight, self.bias, self.normalized_shape, self.eps)
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
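The reworked layers above accept initializer callables (e.g. `init.kaiming_uniform_(a=math.sqrt(5))`) that are later invoked as `initializer(tensor, fan_in=..., fan_out=...)`. The sketch below shows the pattern with a hypothetical factory that mirrors the removed `'torch'` branch; it is not the `colossalai.nn.init` implementation.

```
import math
import torch
from torch import nn

def kaiming_uniform_factory(a=math.sqrt(5)):
    """Hypothetical factory mirroring the removed 'torch' init branch."""
    def initializer(tensor, fan_in, fan_out=None):
        std = nn.init.calculate_gain('leaky_relu', a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        return nn.init.uniform_(tensor, -bound, bound)
    return initializer

weight = torch.empty(128, 64)
kaiming_uniform_factory()(weight, fan_in=64, fan_out=128)
print(weight.abs().max().item())   # bounded by sqrt(3) * gain / sqrt(fan_in)
```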

View File

@ -1,11 +1,6 @@
from ._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D, Add_Bias_2D, matmul_2d
from ._transformer import TransformerMLP2D, TransformerSelfAttention2D, TransformerLayer2D
from ._vit import ViTMLP2D, ViTSelfAttention2D, ViTHead2D, ViTPatchEmbedding2D, ViTTokenFuser2D, ViTInputSplitter2D
from .layers import Linear2D, LayerNorm2D
from ._operation import reduce_by_batch_2d, split_batch_2d
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
__all__ = [
'Matmul_AB_2D', 'Matmul_ABT_2D', 'Matmul_ATB_2D', 'Add_Bias_2D', 'matmul_2d',
'TransformerMLP2D', 'TransformerSelfAttention2D', 'TransformerLayer2D',
'ViTMLP2D', 'ViTSelfAttention2D', 'ViTHead2D', 'ViTPatchEmbedding2D', 'ViTTokenFuser2D', 'ViTInputSplitter2D',
'Linear2D', 'LayerNorm2D'
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
]

View File

@ -1,24 +1,25 @@
from typing import Any, Tuple
from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
def matmul_2d(a,
b,
summa_dim,
out_shape,
row_rank=None,
col_rank=None,
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
def matmul_2d(
a,
b,
summa_dim,
out_shape,
row_rank=None,
col_rank=None,
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
):
"""Matrix multiplication for 2D parallelism
:param a: matrix :math:`A`
:type a: torch.tensor
@ -44,16 +45,87 @@ def matmul_2d(a,
if col_rank is None:
col_rank = gpc.get_local_rank(row_parallel_mode)
data_parallel_rank = 0 if not gpc.is_initialized(
ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = summa_dim ** 2
tensor_parallel_size = summa_dim**2
return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode,
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size
)
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
class classifier_2d(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
bias: Optional[Tensor],
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
B_temp = all_gather(B, -1, col_parallel_mode)
if ctx:
ctx.save_for_backward(A, B_temp)
C = torch.matmul(A, B_temp.transpose(0, 1))
C = all_reduce(C, row_parallel_mode)
ctx.use_bias = bias is not None
if bias is not None:
C = C + bias
out = C.reshape(out_shape)
if ctx:
ctx.summa_dim = summa_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = torch.matmul(output_grad, B)
A_grad = A_grad.reshape(ctx.A_shape)
B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
B_grad = B_grad.reshape(ctx.B_shape)
bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
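A simplified, single-level sketch of the reduction in `classifier_2d` (plain tensors, no process groups, not part of this commit): when the hidden dimension is sharded, each rank's partial logits `A_i @ W_i^T` sum to the full classifier output, which is what the row-group `all_reduce` restores.

```
import torch

q = 2                                                 # tensor-parallel degree along the hidden dim
A = torch.randn(8, 256, dtype=torch.float64)          # (tokens, hidden)
W = torch.randn(1000, 256, dtype=torch.float64)       # (classes, hidden)

full = A @ W.t()
partial = sum(A_i @ W_i.t() for A_i, W_i in zip(A.chunk(q, dim=1), W.chunk(q, dim=1)))
print(torch.allclose(full, partial))                  # True
```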
class Matmul_AB_2D(torch.autograd.Function):
@ -61,19 +133,21 @@ class Matmul_AB_2D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
# A: [b / q, s, h / q] -> [(b * s) / q, h / q]
# B: [h / q, s / q]
# C: [b / q, s, s / q] -> [(b * s) / q, s / q]
@ -116,15 +190,9 @@ class Matmul_AB_2D(torch.autograd.Function):
for i in range(summa_dim):
if i != summa_dim - 1:
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur],
src=src_a + 1,
group=row_group,
async_op=True)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + summa_dim,
group=col_group,
async_op=True)
opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True)
if opa[cur] is not None:
opa[cur].wait()
@ -157,28 +225,14 @@ class Matmul_AB_2D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size)
B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
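The loop above is the SUMMA pattern: column blocks of A and the matching row blocks of B are broadcast along the process-grid rows and columns, and the partial products are accumulated. A standalone sketch (not part of this commit) of the underlying identity:

```
import torch

q = 2
A = torch.randn(16, 64, dtype=torch.float64)
B = torch.randn(64, 32, dtype=torch.float64)

# sum of the q block products equals the full matmul
C = sum(A_k @ B_k for A_k, B_k in zip(A.chunk(q, dim=1), B.chunk(q, dim=0)))
print(torch.allclose(C, A @ B))   # True
```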
@ -187,20 +241,21 @@ class Matmul_ABT_2D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
@ -238,10 +293,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
for i in range(summa_dim):
if i != summa_dim - 1:
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + summa_dim,
group=col_group,
async_op=True)
opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True)
if opr[cur] is not None:
opr[cur].wait()
@ -287,28 +339,14 @@ class Matmul_ABT_2D(torch.autograd.Function):
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_AB_2D.apply(
output_grad, B,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2D.apply(
output_grad, A,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size)
B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
@ -317,20 +355,21 @@ class Matmul_ATB_2D(torch.autograd.Function):
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
summa_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
@ -368,10 +407,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
for i in range(summa_dim):
if i != summa_dim - 1:
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur],
src=src_a + 1,
group=row_group,
async_op=True)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
if opr[cur] is not None:
opr[cur].wait()
@ -417,61 +453,38 @@ class Matmul_ATB_2D(torch.autograd.Function):
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2D.apply(
B, output_grad,
ctx.summa_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2D.apply(
A, output_grad,
ctx.summa_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size)
B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
ctx.tensor_parallel_size)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2D(torch.autograd.Function):
class add_bias_2d(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(
output_size_per_partition,
dtype=bias.dtype,
device=get_current_device())
src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank,
group=gpc.get_group(col_parallel_mode))
def forward(
ctx: Any,
input_: Tensor,
bias: Tensor,
output_size_per_partition: int,
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
bias_temp = all_gather(bias, -1, col_parallel_mode)
ctx.row_rank = row_rank
ctx.col_rank = col_rank
@ -486,62 +499,33 @@ class Add_Bias_2D(torch.autograd.Function):
if skip_bias_add:
return bias_temp
else:
output = input + bias_temp
output = input_ + bias_temp
return output
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
row_rank = ctx.row_rank
col_rank = ctx.col_rank
row_parallel_mode = ctx.row_parallel_mode
col_parallel_mode = ctx.col_parallel_mode
data_parallel_rank = ctx.data_parallel_rank
pipeline_parallel_rank = ctx.pipeline_parallel_rank
pipeline_parallel_size = ctx.pipeline_parallel_size
tensor_parallel_size = ctx.tensor_parallel_size
if ctx.bias:
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(output_grad, dst=dst_rank,
group=gpc.get_group(col_parallel_mode))
if row_rank == 0:
return None, output_grad, None, None, None, None, None, None, None, None, None, None
else:
# for compatibility with zero optimizer, no grad should be None
grad_tmp = torch.zeros_like(output_grad)
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None
grad = reduce_scatter(output_grad, -1, col_parallel_mode)
return None, grad, None, None, None, None, None, None, None, None, None, None
else:
reduce_dim = tuple(range(output_grad.ndim - 1))
reduce = torch.sum(output_grad, dim=reduce_dim)
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(reduce, dst=dst_rank,
group=gpc.get_group(col_parallel_mode))
if row_rank == 0:
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None
else:
# for compatibility with zero optimizer, no grad should be None
reduce_tmp = torch.zeros_like(reduce)
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None
grad = reduce_scatter(reduce, -1, col_parallel_mode)
return output_grad, grad, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2D(torch.autograd.Function):
class layernorm_2d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
input_ = input_ - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.normalized_shape = hidden_size
output = input * Var_x
output = input_ * Var_x
ctx.save_for_backward(output, Var_x)
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
@ -555,14 +539,11 @@ class _LayerNorm_2D(torch.autograd.Function):
x, Var_x = ctx.saved_tensors
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_sum, group=gpc.get_group(row_parallel_mode))
torch.distributed.all_reduce(output_grad_sum, group=gpc.get_group(row_parallel_mode))
output_grad_sum /= ctx.normalized_shape
output_grad_mul_x_sum = torch.sum(
output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
output_grad_mul_x_sum /= ctx.normalized_shape
input_grad = output_grad.clone()
@ -573,69 +554,28 @@ class _LayerNorm_2D(torch.autograd.Function):
return input_grad, None, None, None, None, None
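A standalone numeric check (plain PyTorch, not this class, not part of this commit) of the gradient identity implemented above for an affine-free layer norm: with `y = (x - E[x]) * rstd`, the input gradient is `rstd * (dy - mean(dy) - y * mean(dy * y))`.

```
import torch

x = torch.randn(4, 64, dtype=torch.float64, requires_grad=True)
y = torch.nn.functional.layer_norm(x, (64,))
dy = torch.randn_like(y)
y.backward(dy)

rstd = 1.0 / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5)
yhat = (x - x.mean(dim=-1, keepdim=True)) * rstd
dx = rstd * (dy - dy.mean(dim=-1, keepdim=True) - yhat * (dy * yhat).mean(dim=-1, keepdim=True))
print(torch.allclose(x.grad, dx.detach()))   # True: matches autograd
```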
# class Sum_2D(torch.autograd.Function):
#
# @staticmethod
# def forward(ctx: Any,
# inputs: Tensor,
# dim: int,
# summa_dim: int,
# row_parallel_mode: ParallelMode,
# keepdim: bool = False) -> Tensor:
# # input: [b/q, s, h/q]
# empty_cache()
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
# return out
#
# @staticmethod
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
class AllGatherLast(torch.autograd.Function):
class all_gather_weight_2d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.dim = dim
ctx.summa_dim = summa_dim
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
last_dim = summa_dim * inputs.size(-1)
outputs_shape = (last_dim,) + inputs.shape[:-1]
outputs = torch.empty(
outputs_shape, dtype=inputs.dtype, device=get_current_device())
dist.all_gather(
list(outputs.chunk(summa_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
outputs = all_gather(inputs, dim, col_parallel_mode)
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
return grad.contiguous(), None, None
grad = output_grad.chunk(ctx.summa_dim, dim=ctx.dim)[ctx.row_rank]
return grad.contiguous(), None, None, None
class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
summa_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.summa_dim = summa_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
@ -647,12 +587,33 @@ class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(list(grad.chunk(ctx.summa_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode))
return grad, None, None
def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
class reduce_by_batch_2d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_.clone()
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
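A standalone sketch (manual chunking in place of the column group, not part of this commit) of how `split_batch_2d` and `reduce_by_batch_2d` are meant to compose: splitting the batch and all-reducing per-rank summed metrics recovers the full-batch value.

```
import torch

q = 2
logits = torch.randn(8, 10, dtype=torch.float64)
labels = torch.randint(0, 10, (8,))

full = torch.nn.functional.cross_entropy(logits, labels, reduction='sum')
per_rank = [
    torch.nn.functional.cross_entropy(l, y, reduction='sum')
    for l, y in zip(logits.chunk(q, dim=0), labels.chunk(q, dim=0))
]
print(torch.allclose(full, sum(per_rank)))   # True: all-reduce of per-rank sums
```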

View File

@ -1,220 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from .layers import Linear2D, LayerNorm2D
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2D(ParallelLayer):
"""
MLP will take the input with h hidden state, project it to mlp_ratio * h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
:param in_features: the size of input tensor
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: int, optional
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, the bias add is skipped for the linear layer so that it can be fused into a later kernel, defaults to False
:type skip_bias_add: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2D(
in_features,
int(mlp_ratio * in_features),
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2D(
int(mlp_ratio * in_features),
in_features,
dtype=dtype,
skip_bias_add=self.skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention2D(ParallelLayer):
"""Self attention layer for 2D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.layernorm = LayerNorm2D(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
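# Note (illustrative, not part of the layer above): `attention_mask` is added directly to the
# raw attention scores, so it is expected to be an additive mask broadcastable to
# [b/q, heads/q, s, s]: 0.0 for positions to keep and a large negative value for positions to
# drop. A [batch, seq_len] padding mask of 1s/0s could be converted, for example, with:
#     extended_mask = (1.0 - padding_mask[:, None, None, :].float()) * -10000.0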
@LAYERS.register_module
class TransformerLayer2D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention2D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP2D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
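# Illustrative usage sketch (an assumption, not taken from this diff): one 2D-parallel
# transformer block running a forward pass. It assumes the tensor-parallel context was
# initialized with parallel = dict(tensor=dict(mode='2d', size=4)), so summa_dim == 2 and
# each rank holds a [b/q, s, h/q] shard of the activations. Sizes below are arbitrary.
def _example_transformer_layer_2d():
    summa_dim = get_summa_dim_from_env()
    layer = TransformerLayer2D(hidden_size=768, num_attention_heads=12)
    hidden_states = torch.randn(32 // summa_dim, 128, 768 // summa_dim)
    # additive mask of zeros, i.e. no position is masked out
    attention_mask = torch.zeros(32 // summa_dim, 1, 1, 128)
    return layer(hidden_states, attention_mask)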


@ -1,397 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.nn.layer._common_utils import divide, ACT2FN
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from colossalai.core import global_context as gpc
from ._operation import AllGatherLast, SplitFirst
from .layers import Linear2D
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
@LAYERS.register_module
class ViTMLP2D(ParallelLayer):
"""MLP layer for 2D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
)
# Project back to h.
self.dense_2 = Linear2D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
init_weight=weight_init, init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
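# Illustrative sketch (assumed sizes, not taken from this diff): ViTMLP2D with the fused
# bias-GeLU kernel and activation checkpointing. With act_func='fused_gelu', dense_1 skips
# its bias add and bias_gelu_impl applies the bias inside the fused kernel; checkpoint=True
# recomputes the block in backward to save activation memory.
def _example_vit_mlp_2d():
    summa_dim = get_summa_dim_from_env()
    mlp = ViTMLP2D(in_features=768, mlp_ratio=4, act_func='fused_gelu',
                   dropout_prob=0.1, checkpoint=True)
    x = torch.randn(32 // summa_dim, 197, 768 // summa_dim)   # [b/q, s, h/q] shard
    return mlp(x)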
@LAYERS.register_module
class ViTSelfAttention2D(ParallelLayer):
"""Self-attention layer for 2D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: whether to checkpoint the layer, defaults to False
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2D(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2D(
hidden_size,
hidden_size,
dtype=dtype,
init_weight=weight_init, init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead2D(ParallelLayer):
"""Output layer for 2D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'):
super().__init__()
assert_summa_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.summa_dim = get_summa_dim_from_env()
self.linear = Linear2D(
hidden_size,
num_classes,
dtype=dtype,
init_weight=self.init_weight, init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
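# Illustrative note (assumed sizes): the head reads only the class token (x[:, 0]), so a
# [b/q, s, h/q] sequence becomes per-rank logits of shape [b/q, num_classes/q]:
#     head = ViTHead2D(hidden_size=768, num_classes=1000)
#     logits = head(sequence_output)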
@LAYERS.register_module
class ViTPatchEmbedding2D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // (self.summa_dim ** 2)
with seed(ParallelMode.TENSOR):
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size,
device=get_current_device()
)
self._set_tensor_parallel_attribute()
if weight_init == 'jax':
with seed(ParallelMode.TENSOR):
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
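# Illustrative sketch (assumed sizes): with 224x224 images and 16x16 patches, each rank's
# Conv2d holds embed_dim / q^2 output channels, so the per-rank token tensor is
# [B, 196, embed_dim // summa_dim**2]:
#     patch_embed = ViTPatchEmbedding2D(img_size=224, patch_size=16, embed_dim=768)
#     tokens = patch_embed(torch.randn(8, 3, 224, 224))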
@LAYERS.register_module
class ViTInputSplitter2D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = SplitFirst.apply(
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
return x
@LAYERS.register_module
class ViTTokenFuser2D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
(1, 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = AllGatherLast.apply(
self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
pos_embed = AllGatherLast.apply(
self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
x = x + pos_embed
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x)
return x
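# Illustrative flow (an assumption about how these pieces compose; no model file is shown in
# this diff): the fuser prepends the class token and adds position embeddings, both
# all-gathered along the column group, and the splitter re-partitions the result for the 2D
# transformer blocks:
#     tokens = ViTPatchEmbedding2D(224, 16, 768)(images)
#     tokens = ViTTokenFuser2D(img_size=224, patch_size=16, embed_dim=768)(tokens)
#     tokens = ViTInputSplitter2D()(tokens)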


@ -1,18 +1,22 @@
import math
from typing import Callable
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
from ._utils import get_summa_dim_from_env, assert_summa_initialization
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..base_layer import ParallelLayer
from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
@LAYERS.register_module
@ -30,15 +34,14 @@ class Linear2D(ParallelLayer):
:param skip_bias_add: If set to ``True``, the linear layer skips the bias addition and returns the bias separately so it can be fused into a later kernel, defaults to False
:type skip_bias_add: bool, optional
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype=None,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
@ -52,118 +55,57 @@ class Linear2D(ParallelLayer):
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.input_size_per_partition = divide(
self.in_features, self.summa_dim)
self.hidden_size_per_partition = divide(
self.out_features, self.summa_dim)
self.input_size_per_partition = divide(self.in_features, self.summa_dim)
self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.input_size_per_partition,
self.hidden_size_per_partition,
**factory_kwargs))
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs))
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(
self.hidden_size_per_partition,
**factory_kwargs))
self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs))
else:
self.register_parameter('bias', None)
# initialize parameters
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
output = Matmul_AB_2D.apply(
x,
self.weight,
self.summa_dim,
out_shape,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size)
output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
if self.bias is not None:
if self.skip_bias_add:
bias = Add_Bias_2D.apply(
None,
self.bias,
self.hidden_size_per_partition,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
bias = add_bias_2d.apply(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_size, self.tensor_parallel_size)
return output, bias
else:
output = Add_Bias_2D.apply(
output,
self.bias,
self.hidden_size_per_partition,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
output = add_bias_2d.apply(output, self.bias, self.hidden_size_per_partition, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
False, self.data_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_size, self.tensor_parallel_size)
return output
else:
return output
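# Illustrative sketch (assumed sizes, not part of this diff): the reworked Linear2D takes
# initializer callables from colossalai.nn.init instead of the old 'torch'/'jax' string
# flags, and with skip_bias_add=True it returns (output, bias) so the bias can be fused into
# a later kernel such as bias-GeLU.
def _example_linear_2d():
    summa_dim = get_summa_dim_from_env()
    linear = Linear2D(1024, 4096, skip_bias_add=True,
                      weight_initializer=init.kaiming_uniform_(a=math.sqrt(5)))
    x = torch.randn(16, 64, 1024 // summa_dim)   # [b/q, s, k/q] activation shard
    out, bias = linear(x)                        # out: [b/q, s, h/q], bias: [h/q^2]
    return out, bias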
@ -183,12 +125,7 @@ class LayerNorm2D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
dtype=None
):
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
super().__init__()
# layer norm config
@ -202,63 +139,252 @@ class LayerNorm2D(ParallelLayer):
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.partitioned_partition = divide(normalized_shape, self.summa_dim)
self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
set_tensor_parallel_attribute_by_partition(self.gamma, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.beta, self.summa_dim**2)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL)
bias = Add_Bias_2D.apply(
None, self.beta, self.partitioned_partition,
self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
scale = Add_Bias_2D.apply(
None, self.gamma, self.partitioned_partition,
self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
output = layernorm_2d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
ParallelMode.PARALLEL_2D_COL)
bias = add_bias_2d.apply(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
scale = add_bias_2d.apply(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
return output
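# Illustrative note (assumed sizes): the statistics are taken over the full hidden dimension
# by all-reducing the per-shard sums across the row group (Var[x] = E[x^2] - E[x]^2), while
# gamma and beta are stored as [normalized_shape / q^2] shards:
#     norm = LayerNorm2D(1024)
#     y = norm(x)   # x: [b/q, s, h/q] shard, y: same shape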
@LAYERS.register_module
class PatchEmbedding2D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of the input image
:type in_chans: int
:param embed_size: dimension of embedding
:type embed_size: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_size = embed_size
self.embed_size_per_partition = embed_size // (self.summa_dim**2)
with seed(ParallelMode.TENSOR):
self.weight = Parameter(
torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.cls_token = Parameter(
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = Parameter(
torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
with seed(ParallelMode.TENSOR):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
fan_out = self.embed_size
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor:
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
input_ = split_batch_2d(input_)
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_weight_2d.apply(self.cls_token, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
pos_embed = all_gather_weight_2d.apply(self.pos_embed, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
cls_token = cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + pos_embed
return output
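# Illustrative sketch (assumed sizes, not part of this diff): unlike the removed
# ViTPatchEmbedding2D / ViTTokenFuser2D pair, this layer also owns the class token and
# position embeddings, so one call maps a full image batch to transformer-ready tokens.
def _example_patch_embedding_2d():
    embed = PatchEmbedding2D(img_size=224, patch_size=16, in_chans=3, embed_size=768)
    images = torch.randn(64, 3, 224, 224)   # full batch; split_batch_2d shards it internally
    return embed(images)                    # roughly [64 / q, 197, 768 / q] per rank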
@LAYERS.register_module
class Embedding2D(ParallelLayer):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_2d(input_)
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
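# Illustrative sketch (assumed sizes): each rank stores a [num_embeddings, embedding_dim / q^2]
# shard of the table; the forward all-gathers it along the column group and looks up the
# batch-split token ids.
def _example_embedding_2d():
    embed = Embedding2D(num_embeddings=50304, embedding_dim=1024, padding_idx=0)
    token_ids = torch.randint(0, 50304, (32, 512))
    return embed(token_ids)   # roughly [32 / q, 512, 1024 / q] per rank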
@LAYERS.register_module
class Classifier2D(ParallelLayer):
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
assert_summa_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.input_size_per_partition = divide(self.in_features, self.summa_dim**2)
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0]
row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes, )
return classifier_2d.apply(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
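# Illustrative sketch (assumed sizes): the classifier can own its weight or reuse a shared one
# passed via `weight=` (e.g. a word-embedding shard); any bias is broadcast so every rank
# holds the full [num_classes] vector, and the output class dimension is not partitioned.
def _example_classifier_2d():
    summa_dim = get_summa_dim_from_env()
    head = Classifier2D(in_features=768, num_classes=1000)
    cls_features = torch.randn(32 // summa_dim, 768 // summa_dim)   # [b/q, h/q] CLS features
    return head(cls_features)                                       # -> [b/q, 1000]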


@ -1,12 +1,7 @@
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
from .layers import Linear2p5D, LayerNorm2p5D
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D
__all__ = [
'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
'ViTInputSplitter2p5D',
'Linear2p5D', 'LayerNorm2p5D'
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'Embedding2p5D'
]


@ -2,11 +2,11 @@ from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
@ -22,25 +22,92 @@ def get_parallel_rank(parallel_mode: ParallelMode):
return gpc.get_local_rank(parallel_mode)
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
class classifier_2p5d(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
bias,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
A_shape = A.shape
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1]))
B_temp = all_gather(B, -1, col_parallel_mode)
if ctx:
ctx.save_for_backward(A, B_temp)
C = torch.matmul(A, B_temp.transpose(0, 1))
C = all_reduce(C, row_parallel_mode)
ctx.use_bias = bias is not None
if bias is not None:
C = C + bias
out = C.reshape(out_shape)
if ctx:
ctx.tesseract_dim = tesseract_dim
ctx.row_rank = row_rank
ctx.col_rank = col_rank
ctx.row_parallel_mode = row_parallel_mode
ctx.col_parallel_mode = col_parallel_mode
ctx.A_shape = A_shape
ctx.B_shape = B_shape
ctx.data_parallel_rank = data_parallel_rank
ctx.pipeline_parallel_rank = pipeline_parallel_rank
ctx.pipeline_parallel_size = pipeline_parallel_size
ctx.tensor_parallel_size = tensor_parallel_size
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = torch.matmul(output_grad, B)
A_grad = A_grad.reshape(ctx.A_shape)
B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
B_grad = B_grad.reshape(ctx.B_shape)
bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
class Matmul_AB_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
# A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
# B: [h / dq, s / q]
@ -59,8 +126,8 @@ class Matmul_AB_2p5D(torch.autograd.Function):
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
@ -100,52 +167,26 @@ class Matmul_AB_2p5D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2p5D.apply(
output_grad, B,
ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.apply(
A, output_grad,
ctx.tesseract_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
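# Note (illustrative): the backward above applies the usual matrix-calculus identities for
# C = A B, namely dA = dC B^T and dB = A^T dC, reusing the transposed SUMMA kernels so the
# gradients stay in the same 2.5D layout as the forward operands.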
class Matmul_ABT_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
@ -197,50 +238,25 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_AB_2p5D.apply(
output_grad, B,
ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_ATB_2p5D.apply(
output_grad, A,
ctx.tesseract_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Matmul_ATB_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int):
assert A.shape[-2] == B.shape[-2], \
@ -261,14 +277,12 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
src_a = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a,
group=get_parallel_group(row_parallel_mode))
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
src_c = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c,
group=get_parallel_group(col_parallel_mode))
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
if i == row_rank:
C = C_temp.clone()
@ -295,59 +309,30 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2p5D.apply(
B, output_grad,
ctx.tesseract_dim, ctx.A_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
B_grad = Matmul_AB_2p5D.apply(
A, output_grad,
ctx.tesseract_dim, ctx.B_shape,
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size
)
A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2p5D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
tesseract_dim: int,
row_rank: int,
col_rank: int,
dep_rank: int,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int
) -> Tensor:
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(
output_size_per_partition,
dtype=bias.dtype,
device=get_current_device())
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
src_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
@ -407,14 +392,10 @@ class Add_Bias_2p5D(torch.autograd.Function):
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class _LayerNorm_2p5D(torch.autograd.Function):
class layernorm_2p5d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any,
input: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
row_parallel_mode: ParallelMode) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
@ -432,14 +413,11 @@ class _LayerNorm_2p5D(torch.autograd.Function):
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
with torch.no_grad():
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_sum, group=get_parallel_group(row_parallel_mode))
torch.distributed.all_reduce(output_grad_sum, group=get_parallel_group(row_parallel_mode))
output_grad_sum /= ctx.hidden_size
output_grad_mul_x_sum = torch.sum(
output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(
output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode))
output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True)
torch.distributed.all_reduce(output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode))
output_grad_mul_x_sum /= ctx.hidden_size
input_grad = output_grad.clone()
@ -450,105 +428,28 @@ class _LayerNorm_2p5D(torch.autograd.Function):
return input_grad, None, None, None, None, None, None
# class Sum_2p5D(torch.autograd.Function):
# """Compute the sum of input tensors
# """
# @staticmethod
# def forward(ctx,
# inputs,
# dim,
# tesseract_dim,
# row_parallel_mode,
# keepdim=False):
# # input: [b/q, s, h/q]
# ctx.save_for_backward(inputs)
# # sum: [b/q, s]
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
# torch.distributed.all_reduce(
# out, group=gpc.get_group(row_parallel_mode))
# return out
# @staticmethod
# def backward(ctx, output_grad):
# with torch.no_grad():
# inputs = ctx.saved_tensors
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
# return input_grad, None, None, None, None, None
# class _ViT_Split_2p5D(torch.autograd.Function):
# @staticmethod
# @custom_fwd(cast_inputs=torch.float16)
# def forward(ctx, inputs, batch_size,
# tesseract_dim, tesseract_dep,
# xz_parallel_mode):
# # inputs: [b, s, h/q]
# # output: [b/dq, s, h/q]
# ctx.BATCH_SIZE = batch_size
# ctx.tesseract_dim = tesseract_dim
# ctx.tesseract_dep = tesseract_dep
# ctx.xz_parallel_mode = xz_parallel_mode
# xz_rank = gpc.get_local_rank(xz_parallel_mode)
# output = torch.chunk(inputs, tesseract_dep *
# tesseract_dim, dim=0)[xz_rank]
# output = output.clone()
# return output
# @staticmethod
# @custom_bwd
# def backward(ctx, output_grad):
# # output_grad: [b/dq, s, h/q]
# # grads: [b, s, h/q]
# # *
# grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
# grads = torch.empty(grads_shape,
# dtype=output_grad.dtype,
# device=get_current_device())
# dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
# output_grad.contiguous(),
# group=get_parallel_group(ctx.xz_parallel_mode))
# return grads, None, None, None, None
class AllGatherLast(torch.autograd.Function):
class all_gather_weight_2p5d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
tesseract_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.dim = dim
ctx.tesseract_dim = tesseract_dim
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
last_dim = tesseract_dim * inputs.size(-1)
outputs_shape = (last_dim,) + inputs.shape[:-1]
outputs = torch.empty(
outputs_shape, dtype=inputs.dtype, device=get_current_device())
dist.all_gather(
list(outputs.chunk(tesseract_dim, dim=0)),
inputs.permute(2, 0, 1).contiguous(),
group=gpc.get_group(col_parallel_mode)
)
outputs = outputs.permute(1, 2, 0).contiguous()
outputs = all_gather(inputs, dim, col_parallel_mode)
return outputs
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.tesseract_dim, dim=-1)[ctx.row_rank]
return grad.contiguous(), None, None
grad = output_grad.chunk(ctx.tesseract_dim, dim=ctx.dim)[ctx.row_rank]
return grad.contiguous(), None, None, None
class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
inputs: Tensor,
tesseract_dim: int,
col_parallel_mode: ParallelMode) -> Tensor:
def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.tesseract_dim = tesseract_dim
ctx.batch_size = inputs.size(0)
ctx.para_mode = col_parallel_mode
@ -560,12 +461,33 @@ class SplitFirst(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(
grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(
list(grad.chunk(ctx.tesseract_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode)
)
return grad, None, None
grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode))
return grad, None, None
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
class reduce_by_batch_2p5d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
return input_
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
return input_.clone()
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
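# Illustrative sketch (an assumption about intended use, not shown in this diff):
# split_batch_2p5d shards a full batch along the PARALLEL_2P5D_COL group, and
# reduce_by_batch_2p5d sums a per-shard quantity (e.g. correct-prediction counts for a
# metric) back over that group:
#     local_images = split_batch_2p5d(images)
#     local_logits = model(local_images)
#     correct = (local_logits.argmax(dim=-1) == local_labels).float().sum()
#     correct = reduce_by_batch_2p5d.apply(correct)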


@ -1,220 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor
from colossalai.nn.layer._common_utils import divide
from colossalai.registry import LAYERS
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D, LayerNorm2p5D
from .._common_utils import ACT2FN
from ..base_layer import ParallelLayer
@LAYERS.register_module
class TransformerMLP2p5D(ParallelLayer):
"""
The MLP takes an input with h hidden size, projects it to mlp_ratio * h
hidden dimension, applies a nonlinear transformation, and projects the
state back to h hidden dimension. Dropout is applied at the end.
:param in_features: the size of input tensor
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: float = 4.0,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
skip_bias_add: bool = False
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.in_features = in_features
self.skip_bias_add = skip_bias_add
# Project to h * mlp_ratio.
self.dense_1 = Linear2p5D(
in_features,
int(mlp_ratio * in_features),
dtype=dtype,
skip_bias_add=skip_bias_add
)
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
f'activation function can only be {list(ACT2FN.keys())}'
self.activation_func = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2p5D(
int(mlp_ratio * in_features),
in_features,
dtype=dtype,
skip_bias_add=skip_bias_add
)
self.dropout = nn.Dropout(dropout_prob)
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
if self.skip_bias_add:
intermediate_output, _ = self.dense_1(x)
else:
intermediate_output = self.dense_1(x)
intermediate_output = self.activation_func(intermediate_output)
if self.skip_bias_add:
output, _ = self.dense_2(intermediate_output)
else:
output = self.dense_2(intermediate_output)
output = self.dropout(output)
output = self.layernorm(x + output)
return output
@LAYERS.register_module
class TransformerSelfAttention2p5D(ParallelLayer):
"""Self attention layer for 2.5D parallel Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layer
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layer
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(
num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query_key_value = Linear2p5D(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D(
hidden_size,
hidden_size,
dtype=dtype,
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.layernorm = LayerNorm2p5D(
hidden_size,
dtype=dtype)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_scores = attention_scores + attention_mask
attention_probs = nn.Softmax(dim=-1)(attention_scores)
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
attention_output = self.layernorm(hidden_states + output)
return attention_output
@LAYERS.register_module
class TransformerLayer2p5D(ParallelLayer):
"""Transformer layer which contains a self-attention layer and a MLP layer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
:type mlp_ratio: float, optional
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
:type attention_dropout_prob: float, optional
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
:type hidden_dropout_prob: float, optional
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
):
super().__init__()
self.attention = TransformerSelfAttention2p5D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
)
self.mlp = TransformerMLP2p5D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
)
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
attention_output = self.attention(hidden_states, attention_mask)
output = self.mlp(attention_output)
return output
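For orientation, below is a minimal usage sketch of the encoder layer defined above. It is an illustration only, not part of the committed file: it assumes the 2.5D (tesseract) process groups have already been initialized by Colossal-AI, and the tesseract dimension is assumed to be 2 purely to spell out per-rank shapes.

```
# Sketch only. TransformerLayer2p5D is the class defined above; the tesseract
# dimension below is an assumption used to illustrate per-rank shapes.
import torch
from colossalai.utils import get_current_device

TESSERACT_DIM = 2    # assumed tesseract dimension q

layer = TransformerLayer2p5D(hidden_size=768,
                             num_attention_heads=12,
                             act_func='gelu',
                             mlp_ratio=4,
                             attention_dropout_prob=0.1,
                             hidden_dropout_prob=0.1)

# per-rank activation partition: [local batch, seq, hidden_size / q]
hidden_states = torch.randn(4, 196, 768 // TESSERACT_DIM, device=get_current_device())
# additive mask: 0 keeps a position, a large negative value masks it out
attention_mask = torch.zeros(4, 1, 1, 196, device=get_current_device())

output = layer(hidden_states, attention_mask)    # same shape as hidden_states
```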


@@ -1,421 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import nn as nn, Tensor, distributed as dist
from torch.nn.init import _calculate_fan_in_and_fan_out
from colossalai.context import seed, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.utils import checkpoint
from colossalai.utils import get_current_device
from ._operation import AllGatherLast, SplitFirst
from ._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from .layers import Linear2p5D
from ..base_layer import ParallelLayer
from ..fused_bias_gelu import bias_gelu_impl
from .._common_utils import (ACT2FN, divide, to_2tuple,
set_tensor_parallel_attribute_by_partition)
@LAYERS.register_module
class ViTMLP2p5D(ParallelLayer):
"""MLP layer for 2.5D parallel Vision Transformer
:param in_features: size of each input sample
:type in_features: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param act_func: activation function, defaults to 'gelu'
:type act_func: str, optional
:param dropout_prob: dropout probability, defaults to 0.
:type dropout_prob: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
:type checkpoint: bool, optional
"""
def __init__(self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
assert_tesseract_initialization()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if act_func == 'fused_gelu':
self.act = bias_gelu_impl
skip_dense_1_add_bias = True
else:
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear2p5D(
self.in_features,
self.mlp_ratio * self.in_features,
dtype=dtype,
init_weight=weight_init,
init_bias=weight_init,
skip_bias_add=skip_dense_1_add_bias
)
self.act = ACT2FN[act_func]
# Project back to h.
self.dense_2 = Linear2p5D(
self.mlp_ratio * self.in_features,
self.in_features,
dtype=dtype,
init_weight=weight_init,
init_bias=weight_init
)
self.dropout = nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
if self.act == bias_gelu_impl:
intermediate_output, bias = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output, bias)
else:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
with seed(ParallelMode.TENSOR):
intermediate_output = self.dropout(intermediate_output)
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTSelfAttention2p5D(ParallelLayer):
"""Self-attention layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout_prob: dropout probability for attention layers
:type attention_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
:type checkpoint: bool, optional
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=None,
checkpoint: bool = False,
weight_init='torch'
):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.hidden_size = hidden_size
self.num_attention_heads = divide(
num_attention_heads, self.tesseract_dim) # *
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_bias = 'zero'
else:
self.init_bias = weight_init
self.query_key_value = Linear2p5D(
hidden_size,
3 * hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=self.init_bias
)
self.attention_dropout = nn.Dropout(attention_dropout_prob)
self.dense = Linear2p5D(
hidden_size,
hidden_size,
dtype=dtype,
init_weight=weight_init,
init_bias=self.init_bias
)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(
query_key_value, 3, dim=-1)
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / \
math.sqrt(self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[
:-2] + (self.all_head_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead2p5D(ParallelLayer):
"""Output layer for 2.5D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
hidden_size,
num_classes,
dtype=None,
weight_init='torch'
):
super().__init__()
assert_tesseract_initialization()
assert weight_init in ('torch', 'jax')
if weight_init == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
else:
self.init_weight = weight_init
self.init_bias = weight_init
self.linear = Linear2p5D(
hidden_size,
num_classes,
dtype=dtype,
init_weight=self.init_weight,
init_bias=self.init_bias
)
def forward(self, x: Tensor) -> Tensor:
x = x[:, 0]
x = self.linear(x)
return x
@LAYERS.register_module
class ViTPatchEmbedding2p5D(ParallelLayer):
""" 2.5D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
in_chans=3,
flatten=True,
weight_init='torch'):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2) # *
with seed(ParallelMode.TENSOR):
self.proj = nn.Conv2d(in_chans,
self.embed_dim,
kernel_size=patch_size,
stride=patch_size,
device=get_current_device()
)
self._set_tensor_parallel_attribute()
if weight_init == 'jax':
with seed(ParallelMode.TENSOR):
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
std = math.sqrt(1.0 / fan_in)
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
nn.init.zeros_(self.proj.bias)
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
def forward(self, x: Tensor) -> Tensor:
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
@LAYERS.register_module
class ViTInputSplitter2p5D(ParallelLayer):
"""Split the input tensor for 2D parallel Vision Transformer
"""
def __init__(self):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
def forward(self, x: Tensor) -> Tensor:
x = AllGatherLast.apply(
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
x = SplitFirst.apply(
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
return x
@LAYERS.register_module
class ViTTokenFuser2p5D(ParallelLayer):
"""
Fuse cls token and pos embedding to the input
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param drop_rate: dropout probability, defaults to 0.
:type drop_rate: float, optional
"""
def __init__(self,
img_size,
patch_size,
embed_dim,
drop_rate=0.
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_dim = embed_dim
self.cls_token = nn.Parameter(torch.zeros(
(1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
device=get_current_device()))
self.pos_embed = nn.Parameter(torch.empty(
(1, self.num_patches + 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
device=get_current_device()))
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
def _broadcast_params(self, param) -> None:
" broadcast to all column ranks for data consistency "
if self.tesseract_dep > 1:
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
dist.broadcast(param, src=xz_rank[0],
group=xz_group)
def _sync_grad_hook(self, grad) -> None:
dist.all_reduce(grad, group=gpc.get_group(
ParallelMode.PARALLEL_2P5D_XZ))
grad = grad / self.tesseract_dim # / self.tesseract_dep # *
return grad
def forward(self, x: Tensor) -> Tensor:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = AllGatherLast.apply(
self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
cls_token = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
pos_embed = AllGatherLast.apply(
self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
x = x + pos_embed
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x)
return x


@@ -1,17 +1,23 @@
import math
from typing import Callable
import torch
from torch import Tensor
from torch.nn import Parameter, init as init
from colossalai.context import seed, ParallelMode
import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..base_layer import ParallelLayer
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
split_batch_2p5d)
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
@LAYERS.register_module
@@ -27,16 +33,14 @@ class Linear2p5D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype=None,
dtype: dtype = None,
skip_bias_add: bool = False,
init_weight='torch',
init_bias='torch'
):
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
@@ -52,76 +56,48 @@ class Linear2p5D(ParallelLayer):
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
self.hidden_size_per_partition = divide(
out_features, self.tesseract_dim)
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(
self.input_size_per_partition,
self.hidden_size_per_partition,
**factory_kwargs))
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs))
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(
self.hidden_size_per_partition,
**factory_kwargs))
self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs))
else:
self.register_parameter('bias', None)
# initialize parameters
with seed(ParallelMode.TENSOR):
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
def reset_parameters(self, init_weight, init_bias) -> None:
assert init_weight in ('torch', 'jax', 'zero')
assert init_bias in ('torch', 'jax', 'zero')
# setting
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
# init weight
if init_weight == 'torch':
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
init.uniform_(self.weight, -bound, bound)
elif init_weight == 'jax':
std = math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std
init.uniform_(self.weight, -a, a)
elif init_weight == 'zero':
init.zeros_(self.weight)
# init bias
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
if init_bias == 'torch':
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
elif init_bias == 'jax':
init.normal_(self.bias, std=1e-6)
elif init_bias == 'zero':
init.zeros_(self.bias)
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
output = Matmul_AB_2p5D.apply(
x,
self.weight,
self.tesseract_dim,
out_shape,
self.row_rank, self.col_rank, self.dep_rank,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank,
@@ -132,34 +108,17 @@ class Linear2p5D(ParallelLayer):
if self.bias is not None:
if self.skip_bias_add:
bias = Add_Bias_2p5D.apply(
None,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
bias = Add_Bias_2p5D.apply(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL,
True, self.data_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_size, self.tensor_parallel_size)
return output, bias
else:
output = Add_Bias_2p5D.apply(
output,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
output = Add_Bias_2p5D.apply(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL, False, self.data_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
return output
else:
return output
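For reference, a minimal sketch of driving the reworked Linear2p5D: the weight_initializer/bias_initializer callables replace the old 'torch'/'jax' string flags, and skip_bias_add hands the bias back for later fusion. It assumes the tesseract groups are initialized; the tesseract dimension is only an illustrative assumption.

```
# Sketch only; assumes colossalai has set up the 2.5D (tesseract) groups.
import math
import torch
from colossalai.nn import init
from colossalai.utils import get_current_device

TESSERACT_DIM = 2    # assumed tesseract dimension q

fc = Linear2p5D(1024, 4096,
                weight_initializer=init.kaiming_uniform_(a=math.sqrt(5)),
                bias_initializer=init.xavier_uniform_(a=1, scale=1))

# per-rank input carries in_features / q features on the last dimension
x = torch.randn(4, 196, 1024 // TESSERACT_DIM, device=get_current_device())
y = fc(x)    # last dimension is out_features / q per rank

# with skip_bias_add=True the layer instead returns (output, bias), which is
# how the fused-gelu MLP above defers the bias addition into the activation
```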
@@ -179,12 +138,7 @@ class LayerNorm2p5D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self,
normalized_shape: int,
eps: float = 1e-05,
dtype=None
):
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
super().__init__()
# layer norm config
@@ -199,66 +153,251 @@ class LayerNorm2p5D(ParallelLayer):
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.partitioned_partition = divide(
normalized_shape, self.tesseract_dim) # *
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.gamma = Parameter(torch.ones(
self.partitioned_partition,
**factory_kwargs))
self.beta = Parameter(torch.zeros(
self.partitioned_partition,
**factory_kwargs))
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
set_tensor_parallel_attribute_by_partition(self.gamma, self.tesseract_dim)
set_tensor_parallel_attribute_by_partition(self.beta, self.tesseract_dim)
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
ParallelMode.PARALLEL_2P5D_ROW)
bias = Add_Bias_2p5D.apply(
None, self.beta, self.partitioned_partition,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
scale = Add_Bias_2p5D.apply(
None, self.gamma, self.partitioned_partition,
self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size
)
output = layernorm_2p5d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
bias = Add_Bias_2p5D.apply(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
scale = Add_Bias_2p5D.apply(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output)
return output
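The forward above is ordinary layer normalization with the mean/variance sums all-reduced across the row group and gamma/beta re-assembled through Add_Bias_2p5D. Below is a single-device statement of what it computes plus a usage sketch; the tesseract dimension is an assumption for illustration.

```
# What LayerNorm2p5D computes, ignoring the 2.5D partitioning:
#   E[x]   = mean(x, dim=-1)
#   Var[x] = mean(x * x, dim=-1) - E[x] ** 2           (biased variance)
#   y      = gamma * (x - E[x]) / sqrt(Var[x] + eps) + beta
import torch
from colossalai.utils import get_current_device

TESSERACT_DIM = 2                      # assumed tesseract dimension q
ln = LayerNorm2p5D(1024)               # each rank holds 1024 / q of gamma/beta
x = torch.randn(4, 196, 1024 // TESSERACT_DIM, device=get_current_device())
y = ln(x)                              # same per-rank shape as x
```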
@LAYERS.register_module
class PatchEmbedding2p5D(ParallelLayer):
""" 2D Image to Patch Embedding
:param img_size: iamge size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param embed_dim: dimension of embedding
:type embed_dim: int
:param in_chans: number of channels of input image, defaults to 3
:type in_chans: int, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_size = embed_size
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
with seed(ParallelMode.TENSOR):
self.weight = Parameter(
torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.cls_token = Parameter(
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = Parameter(
torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
with seed(ParallelMode.TENSOR):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
fan_out = self.embed_size
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor:
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
input_ = split_batch_2p5d(input_)
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_weight_2p5d.apply(self.cls_token, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
pos_embed = all_gather_weight_2p5d.apply(self.pos_embed, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
cls_token = cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + pos_embed
return output
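A usage sketch of the new patch embedding. The layer calls split_batch_2p5d internally, so it is fed the full image batch of the local data-parallel rank; the concrete sizes below are assumptions for illustration.

```
# Sketch only; assumes the 2.5D groups are initialized and the batch size is
# divisible by the internal batch split.
import torch
from colossalai.utils import get_current_device

embed = PatchEmbedding2p5D(img_size=224, patch_size=16, in_chans=3, embed_size=768)

images = torch.randn(32, 3, 224, 224, device=get_current_device())
tokens = embed(images)   # [local batch, 1 + (224 // 16) ** 2 patches, local embed dim]
```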
@LAYERS.register_module
class Embedding2p5D(ParallelLayer):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_2p5d(input_)
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
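Embedding2p5D keeps the vocabulary dimension whole and partitions only the embedding dimension; token indices are split along the batch like the other 2.5D inputs. A short sketch with assumed sizes:

```
# Sketch only; vocabulary and embedding sizes are illustrative assumptions.
import torch
from colossalai.utils import get_current_device

vocab = Embedding2p5D(num_embeddings=50304, embedding_dim=1024, padding_idx=0)

token_ids = torch.randint(0, 50304, (8, 512), device=get_current_device())
hidden = vocab(token_ids)   # [local batch, 512, local embedding partition]
```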
@LAYERS.register_module
class Classifier2p5D(ParallelLayer):
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
assert_tesseract_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_COL)[0]
row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_ROW)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes, )
return classifier_2p5d.apply(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
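Classifier2p5D can allocate its own weight or reuse an existing partitioned parameter via the weight argument (for instance to tie the head to an embedding of matching shape). A sketch with assumed tesseract sizes, mirroring input_size_per_partition above:

```
# Sketch only; the tesseract dimension/depth below are assumptions.
import torch
from colossalai.utils import get_current_device

TESSERACT_DIM, TESSERACT_DEP = 2, 2

head = Classifier2p5D(in_features=768, num_classes=1000)

# per-rank features follow input_size_per_partition = in_features / (d * q**2)
features = torch.randn(4, 768 // (TESSERACT_DEP * TESSERACT_DIM ** 2),
                       device=get_current_device())
logits = head(features)   # [..., num_classes]
```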


@@ -1,9 +1,6 @@
from ._operation import Matmul_ABT_3D, Matmul_ATB_3D, Matmul_AB_3D, Mul_3D, Sum_3D, Add_3D, Reduce_3D
from ._vit import ViTHead3D, ViTMLP3D, ViTPatchEmbedding3D, ViTSelfAttention3D
from .layers import Linear3D, LayerNorm3D
from ._operation import reduce_by_batch_3d, split_batch_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D
__all__ = [
'Matmul_ABT_3D', 'Matmul_ATB_3D', 'Matmul_AB_3D', 'Mul_3D', 'Sum_3D', 'Add_3D', 'Reduce_3D',
'ViTHead3D', 'ViTMLP3D', 'ViTPatchEmbedding3D', 'ViTSelfAttention3D',
'Linear3D', 'LayerNorm3D'
'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
]
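With this rework, the 3D public surface becomes the generic layer set plus the two batch helpers. An import sketch follows; the module path is assumed from the repository layout and is not visible in this hunk.

```
# assumed location: colossalai/nn/layer/parallel_3d/__init__.py
from colossalai.nn.layer.parallel_3d import (Classifier3D, Embedding3D, LayerNorm3D,
                                             Linear3D, PatchEmbedding3D,
                                             reduce_by_batch_3d, split_batch_3d)
```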


@@ -1,11 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Optional, Tuple
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.communication import all_gather, all_reduce, reduce_scatter, broadcast, reduce
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
@@ -15,7 +14,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
class linear_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
def forward(ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
@@ -25,33 +24,16 @@ class linear_3d(torch.autograd.Function):
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
assert input_.shape[-1] == weight.shape[0], \
'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
ctx.use_bias = bias is not None
input_ = all_gather(input_, input_dim, input_parallel_mode)
input_ = torch.cat(input_, dim=input_dim)
# weight = all_gather(weight, weight_dim, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)
if bias is not None:
# ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
# src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
# dist.broadcast(bias,
# src=src_rank,
# group=gpc.get_group(output_parallel_mode))
# bias = all_gather(bias, -1, weight_parallel_mode)
output += bias
# ctx.src_rank = src_rank
# ctx.save_for_backward(input_, weight)
# output = torch.matmul(input_, weight)
# dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
# output += bias
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
@@ -63,115 +45,105 @@ class linear_3d(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
# input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
# dist.all_reduce(input_grad,
# group=gpc.get_group(ctx.input_parallel_mode))
# weight_grad = torch.matmul(
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
# output_grad.reshape(-1, output_grad.shape[-1]))
# dist.all_reduce(weight_grad,
# group=gpc.get_group(ctx.weight_parallel_mode))
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)
# bias_grad = torch.sum(output_grad,
# dim=tuple(
# range(len(output_grad.shape))[:-1]))
# bias_grad = reduce_scatter(bias_grad, -1,
# ctx.weight_parallel_mode)
# dist.reduce(bias_grad,
# dst=ctx.src_rank,
# group=gpc.get_group(ctx.output_parallel_mode))
# if gpc.get_local_rank(
# ctx.output_parallel_mode) != gpc.get_local_rank(
# ctx.input_parallel_mode):
# bias_grad = None
# input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
# weight = all_gather(weight, ctx.weight_dim,
# ctx.weight_parallel_mode)
output_grad = all_gather(output_grad, ctx.output_dim,
ctx.output_parallel_mode)
output_grad = torch.cat(output_grad, dim=ctx.output_dim)
async_ops = list()
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
ctx.input_parallel_mode,
async_op=True)
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
output_grad.reshape(-1, output_grad.shape[-1]))
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
# weight_grad = torch.matmul(
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
# output_grad.reshape(-1, output_grad.shape[-1]))
# weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
# ctx.weight_parallel_mode)
if ctx.use_bias:
bias_grad = torch.sum(output_grad,
dim=tuple(
range(len(output_grad.shape))[:-1]))
# bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode)
# dist.all_reduce(bias_grad,
# group=gpc.get_group(ctx.weight_parallel_mode))
weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
input_op.wait()
weight_op.wait()
if ctx.use_bias:
bias_grad = weight_grad[-1]
weight_grad = weight_grad[:-1]
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
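To summarize the collective pattern implemented above: the forward all-gathers the input along its group, multiplies by the local weight shard and reduce-scatters the result along the output group; the backward mirrors it with a reduce-scatter for the input gradient and asynchronous all-reduces for the weight and bias gradients. Ignoring the 3D partitioning, the op is a plain affine map, as in this single-device reference (all names are illustrative):

```
# Single-device reference for the math of linear_3d (full, unpartitioned tensors)
import torch

x = torch.randn(4, 196, 1024)   # input  [m, n, k]
w = torch.randn(1024, 4096)     # weight [k, h]
b = torch.zeros(4096)           # bias   [h]

y = torch.matmul(x, w) + b      # output [m, n, h], re-partitioned across the cube
```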
class layer_norm_3d(torch.autograd.Function):
class classifier_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
ctx.use_bias = bias is not None
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight.transpose(0, 1))
output = all_reduce(output, output_parallel_mode)
if bias is not None:
output += bias
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
async_ops = list()
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
weight_grad = None
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
input_grad = torch.matmul(output_grad, weight)
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
class layernorm_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# mean = torch.sum(input_, dim=-1)
# dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
# mean /= normalized_shape
# mu = input_ - mean
# var = torch.sum(torch.pow(mu, 2), dim=-1)
# dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
# var /= normalized_shape
# std_dev = torch.sqrt(var + eps)
# ctx.save_for_backward(input_, mu, std_dev, weight)
# output = weight * mu / std_dev + bias
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
output_parallel_mode) / normalized_shape
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
output_parallel_mode) / normalized_shape
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
sigma = torch.sqrt(var + eps)
# ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
# src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
# transforms = torch.stack([weight, bias]).contiguous()
# dist.broadcast(transforms,
# src=src_rank,
# group=gpc.get_group(input_parallel_mode))
# transforms = all_gather(transforms, -1, weight_parallel_mode)
# weight, bias = transforms[0], transforms[1]
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z + bias
# ctx.src_rank = src_rank
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
@@ -181,7 +153,7 @@ class layer_norm_3d(torch.autograd.Function):
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
@@ -191,373 +163,63 @@ class layer_norm_3d(torch.autograd.Function):
grads = all_reduce(grads, ctx.input_parallel_mode)
bias_grad, weight_grad = grads[0], grads[1]
# grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
# dist.reduce(grads,
# dst=ctx.src_rank,
# group=gpc.get_group(ctx.input_parallel_mode))
# if gpc.get_local_rank(
# ctx.input_parallel_mode) == gpc.get_local_rank(
# ctx.output_parallel_mode):
# bias_grad, weight_grad = grads[0], grads[1]
# else:
# bias_grad, weight_grad = None, None
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
input_grad = dz / sigma + dvar * 2 * mu / \
ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None
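The backward above is the standard layer-norm gradient with the per-feature sums all-reduced over the output group. Below is a single-device sketch that mirrors the same algebra without the collectives (H is normalized_shape, mu = x - E[x], sigma = sqrt(Var[x] + eps)):

```
# Reference sketch of the gradient computed by layernorm_3d.backward.
import torch

def layernorm_backward_reference(dy, mu, sigma, gamma, H):
    batch_dims = tuple(range(dy.dim() - 1))
    dbeta = dy.sum(dim=batch_dims)                          # grad of beta
    dgamma = (dy * mu / sigma).sum(dim=batch_dims)          # grad of gamma
    dz = dy * gamma
    dvar = (dz * mu * (-0.5) * sigma ** -3).sum(dim=-1, keepdim=True)
    dmean = (dz * (-1.0 / sigma) + dvar * (-2.0) * mu / H).sum(dim=-1, keepdim=True)
    dx = dz / sigma + dvar * 2 * mu / H + dmean / H         # grad of the input
    return dx, dgamma, dbeta
```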
class Matmul_AB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB`
"""
def split_batch_3d(input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
dim: int = 0) -> Tensor:
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
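split_batch_3d simply slices the batch twice, once per group, leaving each rank with batch_size / depth**2 samples; reduce_by_batch_3d below is the matching reduction. A local shape sketch (depth assumed to be 2):

```
# Sketch of the resulting shapes, emulating rank 0 of both groups locally.
import torch

depth = 2                                   # assumed cube dimension
x = torch.randn(256, 3, 224, 224)           # full batch, for illustration only

per_weight_rank = torch.chunk(x, depth, dim=0)[0]               # 128 samples
per_input_rank = torch.chunk(per_weight_rank, depth, dim=0)[0]  # 64 samples
# i.e. each rank keeps batch_size / depth**2 samples
```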
class reduce_by_batch_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
# A: [m/q^2, n, k/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, h/q]
ctx.save_for_backward(A, B)
assert A.shape[-1] == B.shape[0], \
'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)
A_temp = all_gather(A, input_dim, input_parallel_mode)
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
C = torch.matmul(A_temp, B_temp)
out = reduce_scatter(C, output_dim, output_parallel_mode)
ctx.depth = depth
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
ctx.A_dim = input_dim
ctx.B_dim = weight_dim
ctx.C_dim = output_dim
return out
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode)
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode,
ctx.A_group_parallel_mode, ctx.C_dim,
ctx.B_dim, ctx.A_dim)
B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
ctx.A_group_parallel_mode,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode, ctx.A_dim,
ctx.C_dim, ctx.B_dim)
return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ABT_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
# A: [m/q^2, n, h/q]
# B: [k/q, h/q^2]
# C: [m/q^2, n, k/q]
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
C = torch.matmul(A_temp, B_temp.transpose(0, 1))
out = reduce_scatter(C, output_dim, output_parallel_mode)
ctx.depth = depth
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
ctx.A_dim = input_dim
ctx.B_dim = weight_dim
ctx.C_dim = output_dim
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode,
ctx.A_group_parallel_mode, ctx.C_dim,
ctx.B_dim, ctx.A_dim)
B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
ctx.C_group_parallel_mode,
ctx.A_group_parallel_mode,
ctx.B_group_parallel_mode, ctx.C_dim,
ctx.A_dim, ctx.B_dim)
return A_grad, B_grad, None, None, None, None, None, None, None
class Matmul_ATB_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
A: Tensor,
B: Tensor,
depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = 0,
output_dim: int = -1) -> Tensor:
# A: [m/q^2, n, k/q]
# B: [m/q^2, n, h/q]
# C: [k/q, h/q^2]
ctx.save_for_backward(A, B)
A_temp = all_gather(A, input_dim, input_parallel_mode)
A_temp = A_temp.reshape(-1, A.shape[-1])
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
B_temp = B_temp.reshape(-1, B.shape[-1])
C = torch.matmul(A_temp.transpose(0, 1), B_temp)
out = reduce_scatter(C, output_dim, output_parallel_mode)
ctx.depth = depth
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
ctx.A_dim = input_dim
ctx.B_dim = weight_dim
ctx.C_dim = output_dim
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
ctx.B_group_parallel_mode,
ctx.C_group_parallel_mode,
ctx.A_group_parallel_mode, ctx.B_dim,
ctx.C_dim, ctx.A_dim)
B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
ctx.A_group_parallel_mode,
ctx.C_group_parallel_mode,
ctx.B_group_parallel_mode, ctx.A_dim,
ctx.C_dim, ctx.B_dim)
return A_grad, B_grad, None, None, None, None, None, None, None
class Add_3D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# input: [m/q^2, n, h/q]
# bias: [h/q^2]
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
bias_temp = bias.clone()
dist.broadcast(bias_temp,
src=src_rank,
group=gpc.get_group(input_parallel_mode))
# [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
out = input_ + bias_temp
ctx.depth = depth
ctx.src_rank = src_rank
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
# [h/q]
grad = torch.sum(output_grad,
dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
dist.reduce(bias_grad,
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
bias_grad = None
return output_grad, bias_grad, None, None, None, None
class Mul_3D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A * b`
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
# input: [m/q^2, n, h/q]
# bias: [h/q^2]
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
# [h/q^2]
bias_temp = bias.clone()
dist.broadcast(bias_temp,
src=src_rank,
group=gpc.get_group(input_parallel_mode))
# [h/q]
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
# empty_cache()
ctx.save_for_backward(input_, bias_temp)
out = torch.mul(input_, bias_temp)
ctx.depth = depth
ctx.src_rank = src_rank
ctx.A_group_parallel_mode = input_parallel_mode
ctx.B_group_parallel_mode = weight_parallel_mode
ctx.C_group_parallel_mode = output_parallel_mode
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# output_grad: [m/q^2, n, h/q]
with torch.no_grad():
input_, bias = ctx.saved_tensors
# [m/q^2, n, h/q]
input_grad = torch.mul(output_grad, bias)
# [h/q]
grad = torch.mul(output_grad, input_)
grad = torch.sum(grad,
dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
dist.reduce(bias_grad,
dst=ctx.src_rank,
group=gpc.get_group(ctx.A_group_parallel_mode))
if gpc.get_local_rank(
ctx.A_group_parallel_mode) != gpc.get_local_rank(
ctx.C_group_parallel_mode):
bias_grad = None
return input_grad, bias_grad, None, None, None, None
class Sum_3D(torch.autograd.Function):
"""Compute the sum of input tensors
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any,
input_: Tensor,
dim: int,
depth: int,
parallel_mode: ParallelMode,
keepdim: bool = False) -> Tensor:
# input: [m/q^2, n, h/q]
out = torch.sum(input_, dim=dim, keepdim=keepdim)
dist.all_reduce(out, group=gpc.get_group(parallel_mode))
ctx.input_shape = input_.shape
ctx.depth = depth
ctx.group = parallel_mode
ctx.dim = dim
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
with torch.no_grad():
output_grad = output_grad.contiguous()
dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
if len(output_grad.shape) < len(ctx.input_shape):
output_grad = torch.unsqueeze(output_grad, ctx.dim)
dims = [1 for _ in range(len(output_grad.shape))]
dims[ctx.dim] = ctx.input_shape[ctx.dim]
input_grad = output_grad.repeat(tuple(dims))
return input_grad, None, None, None, None, None
class Reduce_3D(torch.autograd.Function):
"""Reduce input tensors
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input_: Tensor, depth: int,
parallel_mode: ParallelMode) -> Tensor:
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
return input_.clone()
@staticmethod
@custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None
# class Slice_3D(torch.autograd.Function):
# """Slice input tensor
# """
# @staticmethod
# @custom_fwd(cast_inputs=torch.float16)
# def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
# parallel_mode: ParallelMode) -> Tensor:
# rank = gpc.get_local_rank(parallel_mode)
# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
output = broadcast(input_, src_rank, input_parallel_mode)
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
# ctx.depth = depth
# ctx.parallel_mode = parallel_mode
# ctx.dim = dim
# ctx.input_shape = input_.shape
# return out
# @staticmethod
# @custom_bwd
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
# with torch.no_grad():
# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
# input_grad.reshape(ctx.input_shape)
# return input_grad, None, None, None
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
else:
input_grad = None
return input_grad, None, None, None


@@ -1,413 +0,0 @@
import math
import os
from typing import Tuple, Optional
import torch
import torch.distributed as dist
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.registry import LAYERS
from colossalai.nn.init import init_bias_, init_weight_
from colossalai.utils import checkpoint, get_current_device
from torch import Tensor, dtype, nn
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute_by_size, to_2tuple
from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group
from .layers import Linear3D
@LAYERS.register_module
class ViTPatchEmbedding3D(nn.Module):
""" 3D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: dimension of embedding
:type embed_size: int
:param drop_prob: dropout probability
:type drop_prob: float
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
drop_prob: float,
flatten: bool = True,
init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.in_chans = in_chans
self.embed_size = embed_size
self.embed_size_per_partition = divide(self.embed_size, self.depth)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'jax_embed'
self.init_bias = 'zero'
self.proj = nn.Conv2d(self.in_chans,
self.embed_size_per_partition,
kernel_size=patch_size,
stride=patch_size)
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.embed_size_per_partition))
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1,
self.embed_size_per_partition))
self.pos_drop = nn.Dropout(drop_prob)
self.reset_parameters(self.init_weight, self.init_bias)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_size(self.proj.weight, self.in_chans * self.embed_size * self.num_patches)
set_tensor_parallel_attribute_by_size(self.proj.bias, self.embed_size)
set_tensor_parallel_attribute_by_size(self.cls_token, 1 * 1 * self.embed_size)
set_tensor_parallel_attribute_by_size(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size)
def reset_parameters(self, init_weight, init_bias):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
# std = math.sqrt(1.0 / fan_in)
# nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
# nn.init.zeros_(self.proj.bias)
if init_weight != 'torch':
init_weight_(self.proj.weight, fan_in, init_method=init_weight)
init_bias_(self.pos_embed, fan_in, init_method=init_weight)
if init_bias != 'torch':
init_bias_(self.proj.bias, fan_in, init_method=init_bias)
self.to(get_current_device())
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
dist.broadcast(self.proj.weight,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
dist.broadcast(self.proj.bias,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
dist.broadcast(self.proj.weight,
src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode))
dist.broadcast(self.proj.bias,
src=input_src_rank,
group=gpc.get_group(self.input_parallel_mode))
self.proj.weight.register_hook(self._sync_grad_hook)
self.proj.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
def _sync_grad_hook(self, grad) -> Tensor:
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
dist.all_reduce(grad, group=gpc.get_group(self.weight_parallel_mode))
return grad
def forward(self, x: Tensor) -> Tensor:
# split a partition from inputs
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.weight_parallel_mode)].contiguous()
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
self.input_parallel_mode)].contiguous()
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# add cls token & pos embedding
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
with seed(ParallelMode.TENSOR):
x = self.pos_drop(x + self.pos_embed)
return x
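
For intuition, the per-device shape flow of this embedding can be reproduced with plain PyTorch on a single device; the batch splitting and gradient synchronization are omitted, and all sizes below are illustrative assumptions rather than values taken from this commit.

```
import torch
import torch.nn as nn

# Illustrative assumptions: 224x224 images, 16x16 patches, 3D depth q = 2,
# so each device holds embed_size / q output channels of the projection.
img_size, patch_size, in_chans, embed_size, depth = 224, 16, 3, 768, 2
embed_per_partition = embed_size // depth
num_patches = (img_size // patch_size) ** 2

proj = nn.Conv2d(in_chans, embed_per_partition, kernel_size=patch_size, stride=patch_size)
cls_token = nn.Parameter(torch.zeros(1, 1, embed_per_partition))
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_per_partition))

x = torch.randn(4, in_chans, img_size, img_size)                # local shard of the batch
x = proj(x).flatten(2).transpose(1, 2)                          # BCHW -> BNC
x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + pos_embed
print(x.shape)                                                  # torch.Size([4, 197, 384])
```
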
@LAYERS.register_module
class ViTSelfAttention3D(nn.Module):
"""Self-attention layer for 3D parallel Vision Transformer
:param hidden_size: hidden size
:type hidden_size: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_probs_dropout_prob: dropout probability for attention layers
:type attention_probs_dropout_prob: float
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param dtype: dtype of parameters, defaults to None
:type dtype: dtype, optional
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
:param checkpoint: whether to use activation checkpointing, defaults to False
:type checkpoint: bool, optional
:param init_method: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
:type init_method: str, optional
"""
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_probs_dropout_prob: float,
hidden_dropout_prob: float,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch'):
super().__init__()
self.depth = get_depth_from_env()
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.hidden_size = hidden_size
self.num_attention_heads = divide(num_attention_heads, self.depth)
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.checkpoint = checkpoint
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'jax'
self.init_bias = 'zero'
self.query_key_value = Linear3D(self.hidden_size,
3 * self.hidden_size,
# self.input_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
self.dense = Linear3D(self.hidden_size,
self.hidden_size,
# self.output_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
# return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value,
3,
dim=-1)
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.attention_head_size)
attention_probs = self.softmax(attention_scores)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size, )
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
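
The head and feature bookkeeping in `_forward` can be checked shape-by-shape with plain PyTorch; the Linear3D layers and all communication are left out, and the sizes are illustrative assumptions.

```
import math
import torch

# Illustrative assumptions: hidden_size 768, 12 heads, 3D depth q = 2,
# so each device handles 12 / q = 6 heads over hidden_size / q = 384 features.
hidden_size, num_heads, depth = 768, 12, 2
heads_per_partition = num_heads // depth
head_size = hidden_size // num_heads

b, s = 4, 197
qkv = torch.randn(b, s, 3 * hidden_size // depth)              # stands in for the Linear3D output
qkv = qkv.view(b, s, heads_per_partition, 3 * head_size).permute(0, 2, 1, 3)
q, k, v = torch.chunk(qkv, 3, dim=-1)                          # each: (b, 6, s, 64)

scores = q @ k.transpose(-1, -2) / math.sqrt(head_size)
probs = scores.softmax(dim=-1)
context = (probs @ v).transpose(1, 2).reshape(b, s, heads_per_partition * head_size)
print(context.shape)                                           # torch.Size([4, 197, 384])
```
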
@LAYERS.register_module
class ViTMLP3D(nn.Module):
"""[summary]
:param hidden_size: hidden size
:type hidden_size: int
:param mlp_ratio: hidden size of MLP divided by embedding dim
:type mlp_ratio: int
:param hidden_dropout_prob: dropout probability for hidden layers
:type hidden_dropout_prob: float
:param hidden_act: activation function for hidden layers
:type hidden_act: str
:param dtype: dtype of parameters, defaults to None
:type dtype: dtype, optional
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
:param checkpoint: whether to use activation checkpointing, defaults to False
:type checkpoint: bool, optional
:param init_method: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
:type init_method: str, optional
"""
def __init__(self,
hidden_size: int,
mlp_ratio: int,
hidden_dropout_prob: float,
hidden_act: str = 'gelu',
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch'):
super().__init__()
# self.depth = get_depth_from_env()
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.hidden_size = hidden_size
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.init_weight = init_method
self.init_bias = init_method
self.dense_1 = Linear3D(self.hidden_size,
self.mlp_ratio * self.hidden_size,
# self.input_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.activation_func = ACT2FN[hidden_act]
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
self.hidden_size,
# self.output_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
self.dropout = nn.Dropout(hidden_dropout_prob)
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
# return self.input_parallel_mode, self.weight_parallel_mode
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.activation_func(intermediate_output)
output = self.dense_2(intermediate_output)
with seed(ParallelMode.TENSOR):
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
@LAYERS.register_module
class ViTHead3D(nn.Module):
"""Output layer for 3D parallel Vision Transformer
:param in_features: size of input tensor
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param dtype: dtype of parameters, defaults to None
:type dtype: dtype, optional
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
:param init_method: parameter initialization scheme, either 'torch' or 'jax', defaults to 'torch'
:type init_method: str, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
dtype: dtype = None,
bias: bool = True,
init_method: str = 'torch'):
super().__init__()
# self.depth = get_depth_from_env()
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
# self.weight_parallel_mode)
self.in_features = in_features
self.num_classes = num_classes
# out_features = math.ceil(self.num_classes /
# (self.depth**2)) * (self.depth**2)
# self.num_classes_per_partition = divide(self.num_classes, self.depth)
self.init_weight = 'torch'
self.init_bias = 'torch'
if init_method == 'jax':
self.init_weight = 'zero'
self.init_bias = 'zero'
self.linear = Linear3D(self.in_features,
self.num_classes,
# self.input_parallel_mode,
# self.weight_parallel_mode,
dtype=dtype,
bias=bias,
init_weight=self.init_weight,
init_bias=self.init_bias)
def forward(self, x: Tensor) -> Tensor:
# [b/q^2, s, h/q] --> [b/q^2, h/q]
x = x[:, 0]
# [b/q^2, h/q] --> [b/q^2, c/q]
x = self.linear(x)
# return x[:, :self.num_classes_per_partition]
return x
def extra_repr(self):
return 'in_features={}, num_classes={}'.format(self.in_features,
self.num_classes)
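
A hedged sketch of how the blocks above compose into an encoder. It assumes a 3D tensor-parallel context has already been initialized (for example `colossalai.launch` with a tensor-parallel size of `depth**3`, i.e. 8 GPUs for depth 2); the import path and the sizes are assumptions, not taken from this commit.

```
import torch.nn as nn
# Hypothetical import path; the classes are the ones defined above.
from colossalai.nn.layer.parallel_3d import (LayerNorm3D, ViTHead3D, ViTMLP3D,
                                             ViTPatchEmbedding3D, ViTSelfAttention3D)

hidden, heads, num_classes = 384, 6, 10   # illustrative sizes, divisible by the depth

class Block3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = LayerNorm3D(hidden)
        self.attn = ViTSelfAttention3D(hidden, heads, 0.0, 0.1)
        self.norm2 = LayerNorm3D(hidden)
        self.mlp = ViTMLP3D(hidden, 4, 0.1)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))

model = nn.Sequential(
    ViTPatchEmbedding3D(img_size=32, patch_size=4, in_chans=3, embed_size=hidden, drop_prob=0.1),
    *[Block3D() for _ in range(2)],
    LayerNorm3D(hidden),
    ViTHead3D(in_features=hidden, num_classes=num_classes),
)
```
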

View File

@ -1,191 +1,311 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import os
from typing import Tuple
from typing import Callable
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
import torch.nn.functional as F
from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn.init import init_bias_, init_weight_
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from torch.nn import init as init
from .._common_utils import divide, set_tensor_parallel_attribute_by_size
from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d,
linear_3d)
from ._utils import (get_depth_from_env, get_last_group,
get_parallel_mode_from_env, swap_in_out_group)
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ._operation import *
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
@LAYERS.register_module
class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
# input_parallel_mode: ParallelMode,
# weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
class LayerNorm3D(ParallelLayer):
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
device=get_current_device(),
dtype=dtype))
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition,
device=get_current_device(),
dtype=dtype))
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
dtype=dtype))
self.variance_epsilon = eps
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_size(self.weight, self.normalized_shape)
set_tensor_parallel_attribute_by_size(self.bias, self.normalized_shape)
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def reset_parameters(self):
init.zeros_(self.bias)
init.ones_(self.weight)
def reset_parameters(self) -> None:
init.zeros_()(self.bias)
init.ones_()(self.weight)
def forward(self, input_: Tensor) -> Tensor:
# '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
# # input: [m/q^2, n, h/q]
# # [m/q^2, n, 1]
# mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
# True) / self.normalized_shape
# # [m/q^2, n, 1]
# var = (input_ - mean).pow(2)
# var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
# True) / self.normalized_shape
# output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
# output = Mul_3D.apply(output, self.weight, self.depth,
# self.input_parallel_mode,
# self.weight_parallel_mode,
# self.output_parallel_mode)
# output = Add_3D.apply(output, self.bias, self.depth,
# self.input_parallel_mode,
# self.weight_parallel_mode,
# self.output_parallel_mode)
# return output
return layer_norm_3d.apply(input_, self.weight, self.bias,
self.normalized_shape,
self.variance_epsilon,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
def extra_repr(self):
return '{}, eps={}'.format(self.normalized_shape,
self.variance_epsilon)
return layernorm_3d.apply(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
@LAYERS.register_module
class Linear3D(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
# input_parallel_mode: ParallelMode,
# weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None,
init_weight: str = 'torch',
init_bias: str = 'torch'):
class Linear3D(ParallelLayer):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
# self.with_bias = bias
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth)
# [k/q, h/q]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
# [h/q]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
dtype=dtype))
else:
self.register_parameter('bias', None)
self.bias = None
self.reset_parameters(init_weight, init_bias)
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
swap_in_out_group()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_size(self.weight, self.in_features * self.out_features)
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_size(self.bias, self.out_features)
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def reset_parameters(self, init_weight, init_bias) -> None:
# setting
fan_in, fan_out = self.in_features, self.out_features
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.out_features
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
# init weight
init_weight_(self.weight, fan_in, fan_out, init_method=init_weight)
dist.broadcast(self.weight,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
# init bias
if self.bias is not None:
init_bias_(self.bias, fan_in, init_method=init_bias)
dist.broadcast(self.bias,
src=weight_src_rank,
group=gpc.get_group(self.weight_parallel_mode))
dist.broadcast(self.bias,
src=output_src_rank,
group=gpc.get_group(self.output_parallel_mode))
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def forward(self, input_: Tensor) -> Tensor:
# # input: [m/q^2, n, k/q]
# # output: [m/q^2, n, h/q]
# output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
# self.input_parallel_mode,
# self.weight_parallel_mode,
# self.output_parallel_mode)
# if self.bias is not None:
# output = Add_3D.apply(output, self.bias, self.depth,
# self.output_parallel_mode,
# self.weight_parallel_mode,
# self.input_parallel_mode)
# return output
return linear_3d.apply(input_, self.weight, self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
return linear_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None)
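
To make the partitioning concrete: on a cube of `depth**3` devices, each device holds a `[in_features/depth, out_features/depth]` shard of the weight (hence `depth**2` partitions) and a `[out_features/depth]` shard of the bias (`depth` partitions). A small pure-Python check of that arithmetic with illustrative sizes:

```
# Illustrative sizes; depth q = 2 corresponds to a 2 x 2 x 2 device cube (8 GPUs).
depth = 2
in_features, out_features = 1024, 4096

weight_shard = (in_features // depth, out_features // depth)    # [k/q, h/q] held per device
bias_shard = out_features // depth                               # [h/q] held per device

weight_partitions = (in_features * out_features) // (weight_shard[0] * weight_shard[1])
bias_partitions = out_features // bias_shard
print(weight_shard, bias_shard)            # (512, 2048) 2048
print(weight_partitions, bias_partitions)  # 4 2, i.e. depth**2 and depth
```
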
@LAYERS.register_module
class Classifier3D(ParallelLayer):
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
def forward(self, input_: Tensor) -> Tensor:
return classifier_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
@LAYERS.register_module
class PatchEmbedding3D(ParallelLayer):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.patch_size = to_2tuple(patch_size)
grid_size = to_2tuple(img_size // patch_size)
num_patches = grid_size[0] * grid_size[1]
self.embed_size = embed_size
embed_size_per_partition = divide(embed_size, self.depth)
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad, self.input_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
fan_out = self.embed_size
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
self.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed
return output
@LAYERS.register_module
class Embedding3D(ParallelLayer):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, self.depth)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = nn.Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output

View File

@ -0,0 +1,3 @@
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath']

View File

@ -0,0 +1,134 @@
import math
from typing import Callable
import torch
import torch.nn.functional as F
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch import nn as nn
from .._common_utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
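
A short, self-contained check of the stochastic-depth behaviour described in the docstring: each sample is zeroed with probability `drop_prob` during training, and the kept samples are rescaled by `1 / keep_prob` so that the expected value is preserved. The snippet mirrors the logic of `drop_path` above rather than importing it.

```
import torch

def drop_path_demo(x, drop_prob=0.25):
    # same logic as drop_path above, always in "training" mode
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = (keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_()
    return x.div(keep_prob) * random_tensor

torch.manual_seed(0)
x = torch.ones(100000, 8)
out = drop_path_demo(x)
print((out.sum(dim=1) > 0).float().mean())   # ~0.75: fraction of samples kept
print(out.mean())                            # ~1.0: expectation is preserved
```
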
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_size))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor:
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed
return output
@LAYERS.register_module
class VanillaClassifier(nn.Module):
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = nn.Parameter(
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
def reset_parameters(self, weight_initializer, bias_initializer):
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tensor:
return F.linear(input_, self.weight, self.bias)
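
One notable design choice above is the optional `weight` argument, which lets the classifier reuse a parameter owned elsewhere (for example a tied embedding weight) instead of allocating its own. Since the forward pass is a plain `F.linear`, the tying pattern can be illustrated with pure PyTorch; the sizes are illustrative assumptions.

```
import torch
import torch.nn.functional as F
from torch import nn

embed_dim, num_classes = 64, 10
tied_weight = nn.Parameter(torch.randn(num_classes, embed_dim) * 0.02)  # owned elsewhere, e.g. an embedding table

features = torch.randn(4, embed_dim)          # e.g. the [CLS] token features
logits = F.linear(features, tied_weight)      # same computation as VanillaClassifier.forward (bias omitted)
print(logits.shape)                           # torch.Size([4, 10])
```
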

View File

@ -1,5 +1,26 @@
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D
from torch import nn
from torch.nn.modules.loss import *
from torch.nn.modules.loss import _Loss
__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D']
from .loss_2d import CrossEntropyLoss2D
from .loss_2p5d import CrossEntropyLoss2p5D
from .loss_3d import CrossEntropyLoss3D
_parallel_cross_entropy = {
'2d': CrossEntropyLoss2D,
'2.5d': CrossEntropyLoss2p5D,
'3d': CrossEntropyLoss3D
}
class CrossEntropyLoss(_Loss):
def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
super().__init__()
if tensor_parallel in [None, '1d']:
reduction = 'mean' if reduction else 'none'
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
else:
self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
def forward(self, *args):
return self.loss(*args)
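
A hedged usage sketch of the unified wrapper above. The import path is an assumption based on this file's location in the diff, and the parallel variants additionally require the corresponding tensor-parallel context to have been initialized; with `tensor_parallel=None` the wrapper simply falls back to `torch.nn.CrossEntropyLoss`.

```
import torch
from colossalai.nn.loss import CrossEntropyLoss   # import path assumed

criterion = CrossEntropyLoss(reduction=True, tensor_parallel=None)    # data-parallel / 1D fallback
# criterion = CrossEntropyLoss(reduction=True, tensor_parallel='3d')  # requires an initialized 3D context

logits = torch.randn(8, 10, requires_grad=True)
targets = torch.randint(0, 10, (8,))
loss = criterion(logits, targets)
loss.backward()
```
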

View File

@ -1,131 +0,0 @@
import torch
import torch.distributed as dist
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
# Subtract the maximum value.
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(
start=0, end=logits.size()[0],
)
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(
ParallelMode.PARALLEL_2D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(
ParallelMode.PARALLEL_2D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=get_current_device())
grad_2d[arange_1d,
masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
class _ReduceByColumn(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
"""Cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.reduction_mean = reduction
def forward(self, logits, targets):
targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank]
loss = _ParallelCrossEntropyLossFunction_2D.apply(
logits, targets,
)
if self.reduction_mean:
loss = _ReduceByColumn.apply(loss) / self.summa_dim
dist_loss = loss.mean()
return dist_loss

View File

@ -1,124 +0,0 @@
import torch
import torch.distributed as dist
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization, \
get_tesseract_dim_dep_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
class _ParallelCrossEntropyLossFunction_2p5D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
def forward(ctx, logits, targets):
# logits: [b/dq, h/q]
# loss: [b/dq]
# targets: [b/dq, h/q]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
# Subtract the maximum value.
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(
start=0, end=logits.size()[0],
)
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=get_current_device())
grad_2d[arange_1d,
masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
class _ReduceByColDep(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
return input_
@staticmethod
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
return input_
@staticmethod
def backward(ctx, grad_output):
return grad_output
@LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss):
"""Cross entropy loss for 2.5D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
assert_tesseract_initialization()
self.xz_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.reduction_mean = reduction
def forward(self, logits, targets):
targets = targets.chunk(self.tesseract_dim *
self.tesseract_dep, dim=0)[self.xz_rank]
loss = _ParallelCrossEntropyLossFunction_2p5D.apply(
logits, targets,
)
if self.reduction_mean:
loss = _ReduceByColDep.apply(
loss) / self.tesseract_dim / self.tesseract_dep
dist_loss = loss.mean()
return dist_loss

View File

@ -1,183 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import torch
import torch.distributed as dist
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_3d._operation import Reduce_3D
from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env,
get_last_group,
get_parallel_mode_from_env)
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.nn.modules.loss import _Loss
class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function):
"""
Adapted from megatron.mpu.cross_entropy
loss[i] = -logits[i][targets] + log(sum(exp(logits[i])))
"""
@staticmethod
def forward(ctx, logits, targets, depth, output_parallel_mode):
# logits: [b/q^2, c/q]
# labels: [b/q^2]
# loss: [b/q^2]
logits_max = torch.max(logits, dim=-1)[0]
dist.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(output_parallel_mode))
# Subtract the maximum value.
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size_per_partition = logits.size()[-1]
rank = gpc.get_local_rank(output_parallel_mode)
vocab_start = rank * vocab_size_per_partition
vocab_end = (rank + 1) * vocab_size_per_partition - 1
# loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(start=0,
end=logits.size()[0],
device=get_current_device())
predicted_logits = logits[arange_1d, masked_target]
predicted_logits = predicted_logits.clone().contiguous().view_as(
targets)
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits,
group=gpc.get_group(output_parallel_mode))
# Loss = log(sum(exp(logits))) - predicted-logit.
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=-1)
dist.all_reduce(sum_exp_logits,
group=gpc.get_group(output_parallel_mode))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
def backward(ctx, output_grad):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
input_grad = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = input_grad.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0,
end=grad_2d.size()[0],
device=get_current_device())
grad_2d[arange_1d,
masked_target] -= (1.0 - target_mask.view(-1).float())
input_grad.mul_(output_grad.unsqueeze(dim=-1))
return input_grad, None, None, None
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
"""Cross entropy loss for 3D parallelism
:param depth: depth for 3D parallelism
:type depth: int
:param input_parallel_mode: parallel mode for input tensor
:type input_parallel_mode: ParallelMode
:param weight_parallel_mode: parallel mode for weight
:type weight_parallel_mode: ParallelMode
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(
self,
# input_parallel_mode,
# weight_parallel_mode,
reduction=True,
label_smoothing=0.0):
super().__init__()
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.input_rank = gpc.get_local_rank(self.input_parallel_mode)
self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode)
self.reduction_mean = reduction
def forward(self, logits, targets):
# split label partition from the entire batch
batch_size = targets.size(0)
targets = torch.chunk(targets, self.depth, dim=0)[self.weight_rank]
targets = torch.chunk(targets, self.depth, dim=0)[self.input_rank]
loss = _ParallelCrossEntropyLossFunction_3D.apply(
logits, targets, self.depth, self.output_parallel_mode)
if self.reduction_mean:
loss = loss.sum()
loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
loss /= batch_size
return loss
# @LOSSES.register_module
# class LabelSmoothingCrossEntropy3D(_Loss):
# """
# NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
# :param input_parallel_mode: parallel mode for input tensor
# :type input_parallel_mode: ParallelMode
# :param weight_parallel_mode: parallel mode for weight
# :type weight_parallel_mode: ParallelMode
# :param smoothing: label smoothing value, defaults to 0.1
# :type smoothing: float
# :param reduction: whether to average the loss, defaults to True
# :type reduction: bool, optional
# """
# def __init__(self,
# input_parallel_mode,
# weight_parallel_mode,
# smoothing=0.1,
# reduction=True):
# super().__init__()
# assert smoothing < 1.0
# self.smoothing = smoothing
# self.confidence = 1. - smoothing
# self.depth = get_depth_from_env()
# self.input_parallel_mode = input_parallel_mode
# self.weight_parallel_mode = weight_parallel_mode
# self.output_parallel_mode = get_last_group(input_parallel_mode,
# weight_parallel_mode)
# self.reduction_mean = reduction
# def forward(self, logits, targets):
# # split label partition from the entire batch
# j = gpc.get_local_rank(self.input_parallel_mode)
# i = gpc.get_local_rank(self.weight_parallel_mode)
# targets = torch.chunk(targets, self.depth, dim=0)[i]
# targets = torch.chunk(targets, self.depth, dim=0)[j]
# exp_logits = torch.exp(logits)
# sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
# self.output_parallel_mode, False)
# log_probs = torch.log(sum_exp_logits) - logits
# nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
# logits, targets, self.depth, self.output_parallel_mode)
# smooth_loss = -log_probs.mean(dim=-1)
# loss = self.confidence * nll_loss + self.smoothing * smooth_loss
# if self.reduction_mean:
# loss = loss.sum()
# loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
# loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
# loss /= batch_size
# return loss

View File

@ -0,0 +1,30 @@
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
"""Cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True, *args, **kwargs):
super().__init__()
assert_summa_initialization()
self.reduction_mean = reduction
self.loss_args = args
self.loss_kwargs = kwargs
def forward(self, logits, targets):
batch_size = targets.size(0)
targets = split_batch_2d(targets)
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.sum()
loss = reduce_by_batch_2d.apply(loss)
loss /= batch_size
return loss

View File

@ -0,0 +1,29 @@
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss):
"""Cross entropy loss for 2.5D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True, *args, **kwargs):
super().__init__()
assert_tesseract_initialization()
self.reduction_mean = reduction
self.loss_args = args
self.loss_kwargs = kwargs
def forward(self, logits, targets):
batch_size = targets.size(0)
targets = split_batch_2p5d(targets)
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.sum()
loss = reduce_by_batch_2p5d.apply(loss)
loss /= batch_size
return loss

View File

@ -0,0 +1,38 @@
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
"""Cross entropy loss for 3D parallelism
:param depth: depth for 3D parallelism
:type depth: int
:param input_parallel_mode: parallel mode for input tensor
:type input_parallel_mode: ParallelMode
:param weight_parallel_mode: parallel mode for weight
:type weight_parallel_mode: ParallelMode
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True, *args, **kwargs):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.reduction_mean = reduction
self.loss_args = args
self.loss_kwargs = kwargs
def forward(self, logits, targets):
batch_size = targets.size(0)
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean:
loss = loss.sum()
loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
loss /= batch_size
return loss
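
The split-batch / reduce-by-batch pattern used by these losses can be verified without any distributed calls: each rank sums the cross entropy over its batch shard, the sums are all-reduced, and the total is divided by the global batch size, which equals the usual mean reduction.

```
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

# Simulate two ranks by chunking the batch, as split_batch_* would.
shards = zip(torch.chunk(logits, 2, dim=0), torch.chunk(targets, 2, dim=0))
partial_sums = [F.cross_entropy(lg, tg, reduction='sum') for lg, tg in shards]
distributed_mean = sum(partial_sums) / targets.size(0)     # all-reduce of sums, then /= batch_size

reference = F.cross_entropy(logits, targets, reduction='mean')
print(torch.allclose(distributed_mean, reference))         # True
```
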

View File

@ -0,0 +1,24 @@
from torch import nn
from ._utils import calc_acc
from .accuracy_2d import Accuracy2D
from .accuracy_2p5d import Accuracy2p5D
from .accuracy_3d import Accuracy3D
_parallel_accuracy = {
'2d': Accuracy2D,
'2.5d': Accuracy2p5D,
'3d': Accuracy3D,
}
class Accuracy(nn.Module):
def __init__(self, tensor_parallel: str = None):
super().__init__()
if tensor_parallel in [None, '1d']:
self.acc = calc_acc
else:
self.acc = _parallel_accuracy[tensor_parallel]()
def forward(self, *args):
return self.acc(*args)

View File

@ -0,0 +1,6 @@
import torch
def calc_acc(logits, targets):
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(targets == preds)
return correct
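
A minimal check of the helper above: take the argmax over the class dimension and count exact matches; the distributed accuracy classes below only add batch splitting and an all-reduce on top of this.

```
import torch

logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
targets = torch.tensor([1, 0, 0])
preds = torch.argmax(logits, dim=-1)    # tensor([1, 0, 1])
correct = torch.sum(targets == preds)   # tensor(2)
print(correct.item())                   # 2
```
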

View File

@ -0,0 +1,17 @@
import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from torch import nn
from ._utils import calc_acc
class Accuracy2D(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, targets):
with torch.no_grad():
targets = split_batch_2d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2d.apply(correct)
return correct

View File

@ -0,0 +1,17 @@
import torch
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from torch import nn
from ._utils import calc_acc
class Accuracy2p5D(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, targets):
with torch.no_grad():
targets = split_batch_2p5d(targets)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_2p5d.apply(correct)
return correct

View File

@ -0,0 +1,21 @@
import torch
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from torch import nn
from ._utils import calc_acc
class Accuracy3D(nn.Module):
def __init__(self):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
def forward(self, logits, targets):
with torch.no_grad():
targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
correct = calc_acc(logits, targets)
correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode)
return correct

View File

@ -164,31 +164,35 @@ class Trainer:
if epoch is None:
progress = tqdm(progress, desc='[Train]')
else:
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
progress = tqdm(progress, desc=f'[Epoch {epoch} / Train]')
self._call_hooks('before_train_epoch')
self._call_timer(action='start', item='train-epoch')
self._call_timer(action='start', item='Train-epoch')
for i in progress:
self._call_hooks('before_train_iter')
self._call_timer(action='start', item='train-step')
self._call_timer(action='start', item='Train-step')
# run 1 training step
self.engine.zero_grad()
logits, label, loss = self.schedule.forward_backward_step(
self.engine, data_iter, forward_only=False, return_loss=True)
self.engine.step()
self._call_timer(action='stop', item='train-step', keep_in_history=True)
self._call_timer(action='stop', item='Train-step', keep_in_history=True)
self._call_hooks('after_train_iter', output=(logits, label, loss))
self._cur_step += 1
if display_progress:
if 'step_metrics' in self.states:
progress.set_postfix(**self.states['step_metrics'])
# stop when max iter is reached
if self._exceed_max_step():
break
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
self._call_timer(action='stop', item='Train-epoch', keep_in_history=True)
self._call_hooks('after_train_epoch')
self._call_timer(action='reset', item='train-step')
self._call_timer(action='reset', item='Train-step')
def _eval(self,
test_dataloader: DataLoader,
@ -206,25 +210,30 @@ class Trainer:
if display_progress:
desc = 'Evaluation'
if epoch is not None:
desc = '[Epoch %d val]' % epoch
desc = '[Epoch %d / Test]' % epoch
progress = tqdm(progress, desc=desc)
self._call_hooks('before_test_epoch')
self._call_timer(action='start', item='test-epoch')
self._call_timer(action='start', item='Test-epoch')
with torch.no_grad():
for _ in progress:
self._call_hooks('before_test_iter')
self._call_timer(action='start', item='test-step')
self._call_timer(action='start', item='Test-step')
logits, label, loss = self.schedule.forward_backward_step(
self.engine, data_iter, forward_only=True, return_loss=True)
self._call_timer(action='stop', item='test-step', keep_in_history=True)
self._call_timer(action='stop', item='Test-step', keep_in_history=True)
self._call_hooks('after_test_iter',
output=(logits, label, loss))
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
if display_progress:
if 'step_metrics' in self.states:
progress.set_postfix(**self.states['step_metrics'])
self._call_timer(action='stop', item='Test-epoch', keep_in_history=True)
self._call_hooks('after_test_epoch')
self._call_hooks('after_test')
self._call_timer(action='reset', item='test-step')
self._call_timer(action='reset', item='test-epoch')
self._call_timer(action='reset', item='Test-step')
self._call_timer(action='reset', item='Test-epoch')
def _exceed_max_step(self):
return self._max_steps is not None and self._cur_step >= self._max_steps
@ -317,7 +326,7 @@ class Trainer:
ranks=[0])
break
self._call_hooks('after_train')
self._call_timer('reset', 'train-epoch')
self._call_timer('reset', 'Train-epoch')
def evaluate(self,
test_dataloader: DataLoader,
@ -374,4 +383,4 @@ class Trainer:
data_iter = iter(simple_dataloader)
output, _, _ = self.schedule.forward_backward_step(
self.engine, data_iter, forward_only=True, return_loss=False)
return output
return output

View File

@ -1,15 +1,12 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
from ._metric_hook import (LossHook, Accuracy2DHook, AccuracyHook, MetricHook,
Accuracy1DHook, Accuracy2p5DHook, Accuracy3DHook)
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
from ._checkpoint_hook import LoadCheckpointHook, SaveCheckpointHook
from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
TensorboardHook)
from ._lr_scheduler_hook import LRSchedulerHook
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
__all__ = [
'BaseHook', 'MetricHook',
'LoadCheckpointHook', 'SaveCheckpointHook',
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
'Accuracy1DHook', 'Accuracy2p5DHook', 'Accuracy3DHook',
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
'LRSchedulerHook'
'BaseHook', 'MetricHook', 'LoadCheckpointHook', 'SaveCheckpointHook', 'LossHook', 'AccuracyHook',
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook',
'ThroughputHook', 'LogMetricByStepHook'
]

View File

@ -16,11 +16,11 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
from ._base_hook import BaseHook
def _format_number(val):
def _format_number(val, prec=5):
if isinstance(val, float):
return f'{val:.5g}'
return f'{val:.{prec}g}'
elif torch.is_tensor(val) and torch.is_floating_point(val):
return f'{val.item():.5g}'
return f'{val.item():.{prec}g}'
return val
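A few illustrative calls to the reworked `_format_number`, assuming the function above is in scope:

```
import torch

_format_number(3.14159265)              # -> '3.1416'  (5 significant digits by default)
_format_number(3.14159265, prec=3)      # -> '3.14'
_format_number(torch.tensor(0.123456))  # -> '0.12346' (floating-point tensors are unwrapped via .item())
_format_number(42)                      # -> 42        (non-float values pass through unchanged)
```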
@ -37,6 +37,24 @@ class LogByEpochHook(BaseHook):
return trainer.cur_epoch % self._interval == 0
@HOOKS.register_module
class LogMetricByStepHook(BaseHook):
def __init__(self, priority: int = 10):
super().__init__(priority)
def after_train_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'
def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'
@HOOKS.register_module
class LogMetricByEpochHook(LogByEpochHook):
"""Specialized Hook to record the metric to log.
@ -61,7 +79,7 @@ class LogMetricByEpochHook(LogByEpochHook):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
msg.append(
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg = ', '.join(msg)
msg = ' | '.join(msg)
return msg
def after_train_epoch(self, trainer):
@ -69,15 +87,15 @@ class LogMetricByEpochHook(LogByEpochHook):
msg = self._get_str(trainer=trainer, mode='train')
if self._is_rank_to_log:
self.logger.info(
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}')
# f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
def after_test_epoch(self, trainer):
if self._is_epoch_to_log(trainer):
msg = self._get_str(trainer=trainer, mode='test')
if self._is_rank_to_log:
self.logger.info(
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
# f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
@ -131,8 +149,7 @@ class TensorboardHook(BaseHook):
log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}')
def _log_by_iter(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
@ -141,16 +158,14 @@ class TensorboardHook(BaseHook):
val = metric_calculator.get_last_step_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
trainer.cur_step)
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
def _log_by_epoch(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
trainer.cur_step)
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
def after_test_iter(self, trainer, *args):
self._log_by_iter(trainer, mode='test')
@ -178,15 +193,13 @@ class LogTimingByEpochHook(LogByEpochHook):
:param log_eval: Whether writes in evaluation
:type log_eval: bool, optional
"""
def __init__(self,
timer: MultiTimer,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
ignore_num_train_steps: int = 0
) -> None:
ignore_num_train_steps: int = 0) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._timer = timer
self._log_eval = log_eval
@ -197,40 +210,39 @@ class LogTimingByEpochHook(LogByEpochHook):
self._ignore_num_train_steps = ignore_num_train_steps
self._is_train_step_history_trimmed = False
def _get_message(self):
def _get_message(self, mode):
msg = []
for timer_name, timer in self._timer:
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
if timer_name == 'train-step' and not self._is_train_step_history_trimmed:
timer._history = timer._history[self._ignore_num_train_steps:]
self._is_train_step_history_trimmed = True
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s')
else:
msg.append(
f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
if timer_name.startswith(mode):
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
if timer_name == 'Train-step' and not self._is_train_step_history_trimmed:
timer._history = timer._history[self._ignore_num_train_steps:]
self._is_train_step_history_trimmed = True
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s'
)
else:
msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
msg = ', '.join(msg)
msg = ' | '.join(msg)
return msg
def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
msg = self._get_message()
self.logger.info(
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={trainer.steps_per_epoch}')
msg = self._get_message('Train')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}, #steps/epoch = {trainer.steps_per_epoch}')
def after_test_epoch(self, trainer):
"""Writes log after finishing a testing epoch.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
msg = self._get_message()
self.logger.info(
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
msg = self._get_message('Test')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
@HOOKS.register_module
@ -246,14 +258,12 @@ class LogMemoryByEpochHook(LogByEpochHook):
:param log_eval: Whether writes in evaluation
:type log_eval: bool, optional
"""
def __init__(self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False
) -> None:
report_cpu: bool = False) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
@ -262,20 +272,16 @@ class LogMemoryByEpochHook(LogByEpochHook):
"""Resets before training.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage('before-train', self.logger)
report_memory_usage('Before-train', self.logger)
def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage(
f'After Train - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
self.logger)
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger)
def after_test(self, trainer):
"""Reports after testing.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
report_memory_usage(
f'After Test - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
self.logger)
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger)
View File
@ -1,9 +1,7 @@
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
from ..metric import LearningRate
from ._metric_hook import LearningRateMetric, MetricHook
@HOOKS.register_module
@ -19,28 +17,28 @@ class LRSchedulerHook(MetricHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
"""
def __init__(self,
lr_scheduler,
by_epoch: bool,
store_lr_in_state: bool = True,
priority: int = 1,
):
def __init__(
self,
lr_scheduler,
by_epoch: bool,
store_lr_in_state: bool = True,
priority: int = 1,
):
super().__init__(priority=priority)
self.by_epoch = by_epoch
self.lr_scheduler = lr_scheduler
self.store_lr_in_state = store_lr_in_state
def after_hook_is_attached(self, trainer):
trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=self.by_epoch,
initial_lr=self.lr_scheduler.get_last_lr()[0])
trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
initial_lr=self.lr_scheduler.get_last_lr()[0])
def after_train_epoch(self, trainer):
if self.by_epoch:
self.lr_scheduler.step()
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
View File
@ -1,11 +1,209 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.distributed as dist
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.utils import is_no_pp_or_last_stage
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
from ._base_hook import BaseHook
from ..metric import Loss, Accuracy1D, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
class Metric(ABC):
"""A basic class of metric collectors. It collects a specific
metric during training or evaluation and it's always used with
:class:`MetricHook` to help it update its states and show the
metric. So please use corresponding hook class to make the metric
collector works.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
# is the metric only read for the full epoch
self._epoch_only = epoch_only
@property
def epoch_only(self):
"""Returns :attr:`epoch_only`.
"""
return self._epoch_only
@abstractmethod
def reset(self) -> None:
"""Resets the metric to its initial state.
By default, this is called at the start of each epoch.
"""
pass
@abstractmethod
def update(self, *args, **kwargs) -> None:
"""Updates the metric's state using the passed batch output.
By default, this is called once for each batch.
"""
pass
@abstractmethod
def get_last_step_value(self):
"""Returns the metric value in the last iteration.
"""
pass
@abstractmethod
def get_accumulated_value(self):
"""Computes the metric based on its accumulated state.
By default, this is called at the end of each epoch.
:return: the actual quantity of interest
:rtype: Any
"""
pass
@staticmethod
@abstractmethod
def is_better(a, b) -> bool:
"""Compares a and b, and returns whether a is better than b
:return: The result of comparison
:rtype: bool
"""
pass
class LossMetric(Metric):
"""A metric collector for loss.
:param epoch_only: Whether the metric is only read at the end of an epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
self.accum_loss = torch.zeros(1, device=get_current_device())
self.count = 0
def reset(self) -> None:
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
"""
self.last_step_loss.zero_()
self.accum_loss.zero_()
self.count = 0
def update(self, loss) -> None:
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
It expects the output to contain the loss.
:param loss: Current loss of the output
"""
# expect output to be logits, label and loss
loss_ = loss.detach()
self.last_step_loss.copy_(loss_)
self.accum_loss.add_(loss_)
self.count += 1
def get_accumulated_value(self):
"""Returns accumulated loss.
"""
if gpc.is_initialized(ParallelMode.DATA):
dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
self.accum_loss.div_(self.count)
return self.accum_loss.item()
def get_last_step_value(self):
"""Returns :attr:`last_step_loss`.
"""
return self.last_step_loss
def is_better(a, b):
return a < b
class LearningRateMetric(Metric):
"""A metric collector for learning rate.
:param epoch_only: Whether the metric is only read at the end of an epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = initial_lr
def reset(self) -> None:
pass
def update(self, lr) -> None:
self.lr = lr
def get_last_step_value(self):
return self.lr
def get_accumulated_value(self):
return self.lr
def is_better(a, b) -> bool:
pass
class AccuracyMetric(Metric):
"""A metric collector for accuracy. It only works for classification
tasks.
:param epoch_only: Whether the metric is only read at the end of an epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, accuracy_func: Callable):
super().__init__(epoch_only=epoch_only)
self.acc = accuracy_func
self.last_step_sum = torch.zeros(1, device=get_current_device())
self.last_step_correct = torch.zeros(1, device=get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_current_device())
def reset(self) -> None:
self.last_step_sum.zero_()
self.last_step_correct.zero_()
self.accumulated_sum.zero_()
self.accumulated_correct.zero_()
def update(self, logits, targets) -> None:
"""Updates last step accuracy and accumulated accuracy with current logits
and labels. It expects the output to contain logits and labels.
:param logits: The logits output of the model
:param targets: The labels of the input data
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(targets, (list, tuple)):
targets = targets[0]
# update
correct = self.acc(logits, targets)
self.last_step_sum.fill_(targets.size(0))
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
def get_last_step_value(self):
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
return (self.last_step_correct / self.last_step_sum).item()
def get_accumulated_value(self):
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
return (self.accumulated_correct / self.accumulated_sum).item()
def is_better(a, b) -> bool:
return a > b
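`AccuracyMetric` delegates the per-batch correctness count to the `accuracy_func` passed into `AccuracyHook`; the train scripts in this commit pass `Accuracy()` from `colossalai.nn.metric`. For illustration only, a compatible plain-PyTorch top-1 counter might look like the hypothetical sketch below (it assumes full, non-partitioned logits):

```
import torch

def top1_correct(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Return the number of correct top-1 predictions in the batch,
    # which is what AccuracyMetric.update() accumulates and later
    # divides by the accumulated sample count.
    preds = torch.argmax(logits, dim=-1)
    return torch.sum(preds == targets)

# hypothetical usage:
# hooks.AccuracyHook(accuracy_func=top1_correct)
```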
class MetricHook(BaseHook):
@ -19,10 +217,10 @@ class MetricHook(BaseHook):
:type trainer: Trainer
:type priority: int
"""
def __init__(self,
priority: int,
):
def __init__(
self,
priority: int,
):
super().__init__(priority)
self._is_stage_to_compute = is_no_pp_or_last_stage()
@ -40,7 +238,6 @@ class LossHook(MetricHook):
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, priority: int = 0):
super().__init__(priority)
@ -48,14 +245,12 @@ class LossHook(MetricHook):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.train_loss = Loss(epoch_only=False)
self.test_loss = Loss(epoch_only=True)
self.train_loss = LossMetric(epoch_only=False)
self.test_loss = LossMetric(epoch_only=True)
# register the metric calculator
trainer.states['metrics']['train'][
self.train_loss.__class__.__name__] = self.train_loss
trainer.states['metrics']['test'][
self.test_loss.__class__.__name__] = self.test_loss
trainer.states['metrics']['train']['Loss'] = self.train_loss
trainer.states['metrics']['test']['Loss'] = self.test_loss
def before_train_epoch(self, trainer):
if self._is_stage_to_compute:
@ -74,124 +269,6 @@ class LossHook(MetricHook):
self.test_loss.update(loss)
@HOOKS.register_module
class Accuracy1DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy1D`.
It acts the same as :class:`AccuracyHook`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy1D(epoch_only=True)
# register the metric
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy2D`.
It acts the same as :class:`AccuracyHook`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy2D(epoch_only=True)
# register the metric
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2p5DHook(MetricHook):
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy2p5D(epoch_only=True)
# register the metric
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy3DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy3D`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int
"""
def __init__(self,
priority: int = 10):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
if self._is_stage_to_compute:
self.metric = Accuracy3D(epoch_only=True)
# register the metric
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class AccuracyHook(MetricHook):
"""Specialized hook class for :class:`Accuracy`.
@ -201,22 +278,87 @@ class AccuracyHook(MetricHook):
:type trainer: Trainer
:type priority: int
"""
def __init__(self, priority: int = 0):
def __init__(self, accuracy_func: Callable, priority: int = 0):
super().__init__(priority)
self.accuracy_func = accuracy_func
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy(epoch_only=True)
self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func)
# register the metric
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
trainer.states['metrics']['test']['Accuracy'] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, label, *args):
def after_test_iter(self, trainer, logits, targets, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
self.metric.update(logits, targets)
class ThroughputMetric(Metric):
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device())
def reset(self) -> None:
self.accumulated_num_samples.zero_()
self.accumulated_used_time.zero_()
self.last_step_num_samples.zero_()
self.last_step_used_time.zero_()
def update(self, tensor, time) -> None:
if isinstance(tensor, (list, tuple)):
tensor = tensor[0]
self.last_step_num_samples.fill_(tensor.size(0))
self.last_step_used_time.fill_(time)
self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time
def get_last_step_value(self):
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
def get_accumulated_value(self):
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
def is_better(a, b) -> bool:
pass
@HOOKS.register_module
class ThroughputHook(MetricHook):
def __init__(self, priority: int = 10):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = ThroughputMetric(epoch_only=True)
# register the metric
trainer.states['metrics']['train']['Throughput'] = self.metric
trainer.states['metrics']['test']['Throughput'] = self.metric
def before_train_epoch(self, trainer):
self.metric.reset()
def after_train_iter(self, trainer, logits, targets, *args):
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
def before_test(self, trainer):
self.metric.reset()
def after_test_iter(self, trainer, logits, targets, *args):
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
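The throughput reported by `ThroughputMetric` is the all-reduced sample count divided by the rank-averaged step time. A small numeric sketch with made-up numbers:

```
# Hypothetical numbers: 4 data-parallel ranks, 128 samples each,
# and a step time of ~0.5 s averaged across ranks.
world_size = 4
samples_per_rank = 128
mean_step_time = 0.5  # seconds, averaged over ranks as in get_last_step_value()

total_samples = samples_per_rank * world_size         # all_reduce(SUM) over ParallelMode.DATA
throughput = total_samples / (mean_step_time + 1e-12)
print(f'{throughput:.0f} samples/s')                  # -> 1024 samples/s
```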
View File
@ -1,356 +0,0 @@
import os
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
from colossalai.communication import all_gather
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer._parallel_utilities import _gather
from colossalai.nn.layer.parallel_3d._utils import (get_last_group,
get_parallel_mode_from_env)
from colossalai.utils import get_current_device
class Metric(ABC):
"""A basic class of metric collectors. It collects a specific
metric during training or evaluation and it's always used with
:class:`MetricHook` to help it update its states and show the
metric. So please use corresponding hook class to make the metric
collector works.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
# is the metric only read for the full epoch
self._epoch_only = epoch_only
@property
def epoch_only(self):
"""Returns :attr:`epoch_only`.
"""
return self._epoch_only
@abstractmethod
def reset(self) -> None:
"""Resets the metric to it's initial state.
By default, this is called at the start of each epoch.
"""
pass
@abstractmethod
def update(self, *args, **kwargs) -> None:
"""Updates the metric's state using the passed batch output.
By default, this is called once for each batch.
"""
pass
@abstractmethod
def get_last_step_value(self):
"""Returns the metric value in the last iteration.
"""
pass
@abstractmethod
def get_accumulated_value(self):
"""Computes the metric based on it's accumulated state.
By default, this is called at the end of each epoch.
:return: the actual quantity of interest
:rtype: Any
"""
pass
@staticmethod
@abstractmethod
def is_better(a, b) -> bool:
"""Compares a and b, and returns whether a is better than b
:return: The result of comparison
:rtype: bool
"""
pass
class Loss(Metric):
"""A metric collector for loss.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
self.accum_loss = torch.zeros(1, device=get_current_device())
self.count = 0
def reset(self) -> None:
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
"""
self.last_step_loss.zero_()
self.accum_loss.zero_()
self.count = 0
def update(self, loss) -> None:
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
It expects the output has loss.
:param loss: Current loss of the output
"""
# expect output to be logits, label and loss
loss_ = loss.detach()
self.last_step_loss.copy_(loss_)
self.accum_loss.add_(loss_)
self.count += 1
def get_accumulated_value(self):
"""Returns accumulated loss.
"""
if gpc.is_initialized(ParallelMode.DATA):
dist.all_reduce(self.accum_loss,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.DATA))
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
self.accum_loss.div_(self.count)
return self.accum_loss.item()
def get_last_step_value(self):
"""Returns :attr:`last_step_loss`.
"""
return self.last_step_loss
def is_better(a, b):
return a < b
class LearningRate(Metric):
"""A metric collector for learning rate.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = 0.
def reset(self) -> None:
pass
def update(self, lr) -> None:
self.lr = lr
def get_last_step_value(self):
return self.lr
def get_accumulated_value(self):
return self.lr
def is_better(a, b) -> bool:
pass
class Accuracy(Metric):
"""A metric collector for accuracy. It only works for classification
tasks.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
self.last_step_sum = torch.zeros(1, device=get_current_device())
self.last_step_correct = torch.zeros(1, device=get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_current_device())
def reset(self) -> None:
self.last_step_sum.zero_()
self.last_step_correct.zero_()
self.accumulated_sum.zero_()
self.accumulated_correct.zero_()
def update(self, logits, label) -> None:
"""Updates last step accuracy and accumulated accuracy with current logits
and labels. It expects the output has logits and labels.
:param logits: The logits output of the model
:param label: The labels of the input data
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(label, (list, tuple)):
label = label[0]
# update
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(label == preds)
self.last_step_sum.fill_(label.size(0))
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
def get_last_step_value(self):
dist.all_reduce(self.last_step_sum,
group=gpc.get_group(ParallelMode.DATA))
dist.all_reduce(self.last_step_correct,
group=gpc.get_group(ParallelMode.DATA))
return (self.last_step_sum / self.last_step_correct).item()
def get_accumulated_value(self):
dist.all_reduce(self.accumulated_sum,
group=gpc.get_group(ParallelMode.DATA))
dist.all_reduce(self.accumulated_correct,
group=gpc.get_group(ParallelMode.DATA))
return (self.accumulated_correct / self.accumulated_sum).item()
def is_better(a, b) -> bool:
return a > b
class Accuracy2D(Accuracy):
"""A metric collector for accuracy. It only works for classification
tasks. This class is the same as :class:`Accuracy` but used in 2D
model parallelism.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
def update(self, logits, label) -> None:
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(label, (list, tuple)):
label = label[0]
logits = _gather(logits, ParallelMode.PARALLEL_2D_ROW, 1)
logits = _gather(
logits,
ParallelMode.PARALLEL_2D_COL,
0,
)
# update
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(label == preds)
self.last_step_sum.fill_(label.size(0))
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
class Accuracy1D(Accuracy):
"""A metric collector for accuracy. It only works for classification
tasks. This class is the same as :class:`Accuracy` but used in 2D
model parallelism.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
def update(self, logits, label) -> None:
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(label, (list, tuple)):
label = label[0]
logits = _gather(
logits,
ParallelMode.PARALLEL_1D,
1
)
# update
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(label == preds)
self.last_step_sum.fill_(label.size(0))
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
class Accuracy2p5D(Accuracy):
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
def update(self, logits, label) -> None:
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(label, (list, tuple)):
label = label[0]
logits = _gather(logits, ParallelMode.PARALLEL_2P5D_ROW, 1)
logits = _gather(
logits,
ParallelMode.PARALLEL_2P5D_COL,
0,
)
logits = _gather(
logits,
ParallelMode.PARALLEL_2P5D_DEP,
0,
)
# update
preds = torch.argmax(logits, dim=-1)
correct = torch.sum(label == preds)
self.last_step_sum.fill_(label.size(0))
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
def is_better(a, b) -> bool:
return a > b
class Accuracy3D(Accuracy):
"""A metric collector for accuracy. It only works for classification
tasks. This class is the same as :class:`Accuracy` but used in 3D
model parallelism.
:param input_parallel_mode: The parallel mode of the input, generally it should be `ParallelMode.PARALLEL_3D_OUTPUT`
:type input_parallel_mode: `ParallelMode`
:param weight_parallel_mode: The parallel mode of the weight, generally it should be `ParallelMode.PARALLEL_3D_WEIGHT`
:type weight_parallel_mode: `ParallelMode`
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only):
# input_parallel_mode, weight_parallel_mode):
super().__init__(epoch_only=epoch_only)
self.depth = int(os.environ['DEPTH_3D'])
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
def update(self, logits, target):
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(target, (list, tuple)):
target = target[0]
batch_size = target.size(0)
j = gpc.get_local_rank(self.input_parallel_mode)
i = gpc.get_local_rank(self.weight_parallel_mode)
target = torch.chunk(target, self.depth, dim=0)[i]
target = torch.chunk(target, self.depth, dim=0)[j]
logits = all_gather(logits, -1, self.output_parallel_mode)
logits = torch.cat(logits, dim=-1)
prediction = torch.argmax(logits, dim=-1)
correct = torch.sum(prediction == target)
dist.all_reduce(correct, group=gpc.get_group(self.input_parallel_mode))
dist.all_reduce(correct,
group=gpc.get_group(self.weight_parallel_mode))
self.last_step_sum.fill_(batch_size)
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
View File
@ -48,14 +48,14 @@ def report_memory_usage(message, logger=None, report_cpu=False):
gpu_cached = bytes_to_MB(torch.cuda.memory_reserved())
gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved())
full_log = f"{message} - GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, \
cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
if report_cpu:
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
gc.collect()
vm_stats=psutil.virtual_memory()
vm_used=bytes_to_MB(vm_stats.total - vm_stats.available)
vm_stats = psutil.virtual_memory()
vm_used = bytes_to_MB(vm_stats.total - vm_stats.available)
full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
if logger is None:
View File
@ -13,9 +13,7 @@ from tqdm import tqdm
def main():
colossalai.launch_from_torch(config='./config.py',
host='localhost',
port=29500)
colossalai.launch_from_torch(config='./config.py')
logger = get_dist_logger()
View File
@ -1,22 +1,22 @@
import os
from pathlib import Path
from colossalai.logging import get_dist_logger
import colossalai
import torch
import os
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, MultiTimer
from colossalai.logging import get_dist_logger
from colossalai.nn import CosineAnnealingLR
from colossalai.nn.metric import Accuracy
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from torchvision import transforms
from colossalai.trainer import hooks, Trainer
from torchvision.datasets import CIFAR10
from torchvision.models import resnet34
from colossalai.nn import CosineAnnealingLR
from tqdm import tqdm
def main():
colossalai.launch_from_torch(config='./config.py',
host='localhost',
port=29500)
colossalai.launch_from_torch(config='./config.py')
logger = get_dist_logger()
@ -93,7 +93,7 @@ def main():
hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
hooks.AccuracyHook(),
hooks.AccuracyHook(accuracy_func=Accuracy()),
hooks.LogMetricByEpochHook(logger),
hooks.LogMemoryByEpochHook(logger),
hooks.LogTimingByEpochHook(timer, logger),
View File
@ -19,5 +19,5 @@ dataset = dict(
)
gradient_accumulation=2
gradient_clipping=1.0
clip_grad_norm=1.0
View File
@ -20,4 +20,4 @@ dataset = dict(
)
gradient_accumulation=1
gradient_clipping=1.0
clip_grad_norm=1.0
View File
@ -1,3 +1,4 @@
from colossalai.nn.metric import Accuracy
import torch
import colossalai
from colossalai.core import global_context as gpc
@ -40,9 +41,7 @@ def build_dataset_test():
)
def main():
colossalai.launch_from_torch(config='./le_config.py',
host='localhost',
port=29500)
colossalai.launch_from_torch(config='./le_config.py')
# get logger
logger = get_dist_logger()
@ -81,7 +80,7 @@ def main():
# build hooks
hook_list = [
hooks.LossHook(),
hooks.AccuracyHook(),
hooks.AccuracyHook(accuracy_func=Accuracy()),
hooks.LogMetricByEpochHook(logger),
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
TotalBatchsizeHook(),
View File
@ -41,9 +41,7 @@ def build_dataset_test():
)
def main():
colossalai.launch_from_torch(config='./config.py',
host='localhost',
port=29500)
colossalai.launch_from_torch(config='./config.py')
# get logger
logger = get_dist_logger()
View File
@ -39,11 +39,7 @@ In your training script:
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
colossalai.launch_from_torch(config=args.config,
host=args.host,
port=args.port,
backend=args.backend
)
colossalai.launch_from_torch(config=args.config)
```
In your terminal
View File
@ -11,7 +11,7 @@ fp16 = dict(
)
gradient_accumulation = 16
gradient_clipping = 1.0
clip_grad_norm = 1.0
dali = dict(
# root='./dataset/ILSVRC2012_1k',
View File
@ -2,6 +2,7 @@ import glob
from math import log
import os
import colossalai
from colossalai.nn.metric import Accuracy
import torch
from colossalai.context import ParallelMode
@ -54,11 +55,15 @@ def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
args = parser.parse_args()
# launch from slurm batch job
colossalai.launch_from_slurm(config=args.config,
host=args.host,
port=args.port,
backend=args.backend
)
# launch from torch
# colossalai.launch_from_torch(config=args.config)
# get logger
logger = get_dist_logger()
@ -91,7 +96,7 @@ def main():
# build hooks
hook_list = [
hooks.LossHook(),
hooks.AccuracyHook(),
hooks.AccuracyHook(accuracy_func=Accuracy()),
hooks.LogMetricByEpochHook(logger),
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
TotalBatchsizeHook(),
View File
@ -0,0 +1 @@
from .vit import *
View File
@ -1,208 +0,0 @@
import torch
from torch import nn
from colossalai import nn as col_nn
from colossalai.context import ParallelMode
from colossalai.registry import MODELS
__all__ = [
'VisionTransformer3D',
'vit_tiny_1d_patch4_32',
'vit_tiny_1d_patch16_224',
'vit_tiny_1d_patch16_384',
'vit_small_1d_patch16_224',
'vit_small_1d_patch16_384',
'vit_small_1d_patch32_224',
'vit_small_1d_patch32_384',
'vit_base_1d_patch16_224',
'vit_base_1d_patch16_384',
'vit_base_1d_patch32_224',
'vit_base_1d_patch32_384',
'vit_large_1d_patch16_224',
'vit_large_1d_patch16_384',
'vit_large_1d_patch32_224',
'vit_large_1d_patch32_384',
]
class ViTBlock1D(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
hidden_dim: int,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
self.attn = col_nn.ViTSelfAttention1D(dim, num_heads, attn_drop, drop)
self.drop_path = col_nn.VanillaViTDropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
self.mlp = col_nn.ViTMLP1D(dim, 1, drop, 'gelu')
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@MODELS.register_module
class VisionTransformer1D(nn.Module):
def __init__(self,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
num_heads: int = 12,
embed_dim: int = 768,
hidden_dim: int = 3072,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed = col_nn.ViTPatchEmbedding1D(
img_size,
patch_size,
in_chans,
embed_dim,
drop_rate,
)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(*[
ViTBlock1D(embed_dim, num_heads, hidden_dim,
drop_rate, attn_drop_rate, dpr[i])
for i in range(depth)
])
self.norm = nn.LayerNorm(embed_dim, ParallelMode.PARALLEL_3D_INPUT,
ParallelMode.PARALLEL_3D_WEIGHT)
self.head = col_nn.ViTHead1D(hidden_dim, num_classes)
self.init_weights()
def init_weights(self):
pass
def forward(self, x):
x = self.patch_embed(x)
x = self.blocks(x)
x = self.norm(x)
x = self.head(x)
return x
def _create_vit_model(**model_kwargs):
model = VisionTransformer1D(**model_kwargs)
return model
@MODELS.register_module
def vit_tiny_1d_patch4_32(**kwargs):
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_1d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=192,
depth=12, num_heads=3, hidden_dim=768, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_1d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_1d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=384,
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_1d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_1d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=384,
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_1d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32,
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_1d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768,
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_1d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_3d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=768,
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_1d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32,
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_3d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=1024,
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_1d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_1d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=1024,
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_1d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32,
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
View File
@ -1 +0,0 @@
from .vit import *
View File
@ -1,219 +0,0 @@
from colossalai.context import ParallelMode, seed
from colossalai import nn as clsl_nn
from colossalai.registry import MODELS
from torch import nn
import torch
__all__ = [
'VisionTransformer2D',
'vit_tiny_2d_patch4_32',
'vit_tiny_2d_patch16_224',
'vit_tiny_2d_patch16_384',
'vit_small_2d_patch16_224',
'vit_small_2d_patch16_384',
'vit_small_2d_patch32_224',
'vit_small_2d_patch32_384',
'vit_base_2d_patch16_224',
'vit_base_2d_patch16_384',
'vit_base_2d_patch32_224',
'vit_base_2d_patch32_384',
'vit_large_2d_patch16_224',
'vit_large_2d_patch16_384',
'vit_large_2d_patch32_224',
'vit_large_2d_patch32_384',
]
class ViTBlock2D(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: int = 4,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: str = 'gelu'):
super().__init__()
self.norm1 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
self.attn = clsl_nn.ViTSelfAttention2D(dim, num_heads, attn_drop, drop)
self.drop_path = clsl_nn.VanillaViTDropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
self.mlp = clsl_nn.ViTMLP2D(dim, mlp_ratio, act_layer, drop)
def forward(self, x):
y = self.attn(self.norm1(x))
with seed(ParallelMode.TENSOR):
x = x + self.drop_path(y)
y = self.mlp(self.norm2(x))
with seed(ParallelMode.TENSOR):
x = x + self.drop_path(y)
return x
@MODELS.register_module
class VisionTransformer2D(nn.Module):
def __init__(self,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: int = 4,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
act_layer: str = 'gelu'):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed = clsl_nn.ViTPatchEmbedding2D(
img_size, patch_size, embed_dim, in_chans
)
self.splitter = clsl_nn.ViTInputSplitter2D()
self.token_fuser = clsl_nn.ViTTokenFuser2D(
img_size, patch_size, embed_dim, drop_rate
)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(*[
ViTBlock2D(embed_dim, num_heads, mlp_ratio, drop_rate,
attn_drop_rate, dpr[i], act_layer)
for i in range(depth)
])
self.norm = clsl_nn.LayerNorm2D(embed_dim, eps=1e-6)
self.head = clsl_nn.ViTHead2D(self.num_features, num_classes) if num_classes > 0 \
else nn.Identity()
self.init_weights()
def init_weights(self):
pass
def forward(self, x):
x = self.patch_embed(x)
x = self.splitter(x)
x = self.token_fuser(x)
x = self.blocks(x)
x = self.norm(x)
x = self.head(x)
return x
def _create_vit_model(**model_kwargs):
model = VisionTransformer2D(**model_kwargs)
return model
@MODELS.register_module
def vit_tiny_2d_patch4_32(**kwargs):
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
depth=6, num_heads=8, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=192,
depth=12, num_heads=3, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=192,
depth=12, num_heads=3, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_2d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=384,
depth=12, num_heads=6, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_2d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=768,
depth=12, num_heads=12, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_2d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=1024,
depth=24, num_heads=16, **kwargs)
return _create_vit_model(**model_kwargs)
View File
@ -1 +0,0 @@
from .vit import *
View File
@ -1,209 +0,0 @@
import torch
from torch import nn
from colossalai import nn as col_nn
from colossalai.context import ParallelMode
from colossalai.registry import MODELS
__all__ = [
'VisionTransformer3D',
'vit_tiny_3d_patch4_32',
'vit_tiny_3d_patch16_224',
'vit_tiny_3d_patch16_384',
'vit_small_3d_patch16_224',
'vit_small_3d_patch16_384',
'vit_small_3d_patch32_224',
'vit_small_3d_patch32_384',
'vit_base_3d_patch16_224',
'vit_base_3d_patch16_384',
'vit_base_3d_patch32_224',
'vit_base_3d_patch32_384',
'vit_large_3d_patch16_224',
'vit_large_3d_patch16_384',
'vit_large_3d_patch32_224',
'vit_large_3d_patch32_384',
]
class ViTBlock3D(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
hidden_dim: int,
drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.):
super().__init__()
self.norm1 = col_nn.LayerNorm3D(
dim, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6)
self.attn = col_nn.ViTSelfAttention3D(dim, num_heads, attn_drop, drop)
self.drop_path = col_nn.VanillaViTDropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = col_nn.LayerNorm3D(dim, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6)
self.mlp = col_nn.ViTMLP3D(hidden_dim, 1, drop, 'gelu')
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@MODELS.register_module
class VisionTransformer3D(nn.Module):
def __init__(self,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
num_heads: int = 12,
embed_dim: int = 768,
hidden_dim: int = 3072,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.patch_embed = col_nn.ViTPatchEmbedding3D(
img_size,
patch_size,
in_chans,
embed_dim,
drop_rate,
)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.Sequential(*[
ViTBlock3D(embed_dim, num_heads, hidden_dim,
drop_rate, attn_drop_rate, dpr[i])
for i in range(depth)
])
self.norm = col_nn.LayerNorm3D(embed_dim, ParallelMode.PARALLEL_3D_INPUT,
ParallelMode.PARALLEL_3D_WEIGHT)
self.head = col_nn.ViTHead3D(hidden_dim, num_classes)
self.init_weights()
def init_weights(self):
pass
def forward(self, x):
x = self.patch_embed(x)
x = self.blocks(x)
x = self.norm(x)
x = self.head(x)
return x
def _create_vit_model(**model_kwargs):
model = VisionTransformer3D(**model_kwargs)
return model
@MODELS.register_module
def vit_tiny_3d_patch4_32(**kwargs):
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_3d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=192,
depth=12, num_heads=3, hidden_dim=768, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_3d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_3d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=384,
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_3d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_3d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=384,
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_3d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32,
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_3d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768,
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_3d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_3d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=768,
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_3d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32,
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_3d_patch16_224(**kwargs):
model_kwargs = dict(patch_size=16, embed_dim=1024,
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_3d_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16,
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_3d_patch32_224(**kwargs):
model_kwargs = dict(patch_size=32, embed_dim=1024,
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_3d_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32,
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
return _create_vit_model(**model_kwargs)
487
model_zoo/vit/vit.py Normal file
View File
@ -0,0 +1,487 @@
import math
from typing import Callable
import torch
from colossalai import nn as col_nn
from colossalai.context import ParallelMode, seed
from colossalai.registry import LAYERS, MODELS
from colossalai.utils import checkpoint
from torch import dtype, nn
__all__ = [
'VisionTransformer',
'vit_lite_depth7_patch4_32',
'vit_tiny_patch4_32',
'vit_tiny_patch16_224',
'vit_tiny_patch16_384',
'vit_small_patch16_224',
'vit_small_patch16_384',
'vit_small_patch32_224',
'vit_small_patch32_384',
'vit_base_patch16_224',
'vit_base_patch16_384',
'vit_base_patch32_224',
'vit_base_patch32_384',
'vit_large_patch16_224',
'vit_large_patch16_384',
'vit_large_patch32_224',
'vit_large_patch32_384',
]
_init_rules = dict(
torch=dict(
embed=dict(
weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
position_embed_initializer=col_nn.init.zeros_(),
),
transformer=dict(
weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
),
head=dict(
weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
),
),
jax=dict(
embed=dict(
weight_initializer=col_nn.init.lecun_normal_(),
bias_initializer=col_nn.init.zeros_(),
position_embed_initializer=col_nn.init.trunc_normal_(std=.02),
),
transformer=dict(
weight_initializer=col_nn.init.xavier_uniform_(),
bias_initializer=col_nn.init.normal_(std=1e-6),
),
head=dict(
weight_initializer=col_nn.init.zeros_(),
bias_initializer=col_nn.init.zeros_(),
),
),
)
@LAYERS.register_module
class ViTEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embedding_dim: int,
dropout: float,
dtype: dtype = None,
flatten: bool = True,
init_method: str = 'torch',
tensor_parallel: str = None):
super().__init__()
self.patch_embed = col_nn.PatchEmbedding(img_size,
patch_size,
in_chans,
embedding_dim,
dtype=dtype,
flatten=flatten,
tensor_parallel=tensor_parallel,
**_init_rules[init_method]['embed'])
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.patch_embed(x)
with seed(ParallelMode.TENSOR):
x = self.dropout(x)
return x
@LAYERS.register_module
class ViTSelfAttention(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
attention_dropout: float,
dropout: float,
bias: bool = True,
dtype: dtype = None,
checkpoint: bool = False,
init_method: str = 'torch',
tensor_parallel: str = None):
super().__init__()
self.attention_head_size = dim // num_heads
self.checkpoint = checkpoint
self.tensor_parallel = tensor_parallel
self.query_key_value = col_nn.Linear(dim,
3 * dim,
dtype=dtype,
bias=bias,
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer'])
self.attention_dropout = nn.Dropout(attention_dropout)
self.dense = col_nn.Linear(dim,
dim,
dtype=dtype,
bias=True,
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer'])
self.dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
def _forward(self, x):
qkv = self.query_key_value(x)
all_head_size = qkv.shape[-1] // 3
num_attention_heads = all_head_size // self.attention_head_size
new_qkv_shape = qkv.shape[:-1] + \
(num_attention_heads, 3 * self.attention_head_size)
qkv = qkv.view(new_qkv_shape)
qkv = qkv.permute((0, 2, 1, 3))
q, k, v = torch.chunk(qkv, 3, dim=-1)
x = torch.matmul(q, k.transpose(-1, -2))
x = x / math.sqrt(self.attention_head_size)
x = self.softmax(x)
with seed(ParallelMode.TENSOR):
x = self.attention_dropout(x)
x = torch.matmul(x, v)
x = x.transpose(1, 2)
new_context_layer_shape = x.size()[:-2] + (all_head_size, )
x = x.reshape(new_context_layer_shape)
x = self.dense(x)
if self.tensor_parallel == '1d':
x = self.dropout(x)
else:
with seed(ParallelMode.TENSOR):
x = self.dropout(x)
return x
def _checkpoint_forward(self, x):
return checkpoint(self._forward, x)
def forward(self, x):
if self.checkpoint:
return self._checkpoint_forward(x)
else:
return self._forward(x)
@LAYERS.register_module
class ViTMLP(nn.Module):
def __init__(self,
dim: int,
mlp_ratio: int,
activation: Callable,
dropout: float,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch',
tensor_parallel: str = None):
super().__init__()
self.checkpoint = checkpoint
self.tensor_parallel = tensor_parallel
self.dense_1 = col_nn.Linear(dim,
mlp_ratio * dim,
dtype=dtype,
bias=bias,
tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer'])
self.activation = activation
self.dense_2 = col_nn.Linear(mlp_ratio * dim,
dim,
dtype=dtype,
bias=bias,
tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
**_init_rules[init_method]['transformer'])
self.dropout = nn.Dropout(dropout)
def _forward(self, x):
x = self.dense_1(x)
x = self.activation(x)
with seed(ParallelMode.TENSOR):
x = self.dropout(x)
x = self.dense_2(x)
if self.tensor_parallel == '1d':
x = self.dropout(x)
else:
with seed(ParallelMode.TENSOR):
x = self.dropout(x)
return x
def _checkpoint_forward(self, x):
return checkpoint(self._forward, x)
def forward(self, x):
if self.checkpoint:
return self._checkpoint_forward(x)
else:
return self._forward(x)
@LAYERS.register_module
class ViTHead(nn.Module):
def __init__(self,
dim: int,
num_classes: int,
representation_size: int = None,
dtype: dtype = None,
bias: bool = True,
init_method: str = 'torch',
tensor_parallel: str = None):
super().__init__()
if representation_size:
tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel}
if tensor_parallel == '1d':
tensor_parallel_kwargs['gather_output'] = True
self.representation = col_nn.Linear(dim,
representation_size,
bias=bias,
dtype=dtype,
**_init_rules[init_method]['head'],
**tensor_parallel_kwargs)
else:
self.representation = None
representation_size = dim
self.linear = col_nn.Classifier(representation_size,
num_classes,
dtype=dtype,
bias=bias,
tensor_parallel=tensor_parallel,
**_init_rules[init_method]['head'])
def forward(self, x):
x = x[:, 0]
if self.representation is not None:
x = self.representation(x)
x = self.linear(x)
return x
@LAYERS.register_module
class ViTBlock(nn.Module):
def __init__(self,
dim: int,
num_heads: int,
mlp_ratio: int,
activation: Callable,
attention_dropout: float = 0.,
dropout: float = 0.,
drop_path: float = 0.,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch',
tensor_parallel: str = None):
super().__init__()
self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
self.attn = ViTSelfAttention(dim=dim,
num_heads=num_heads,
attention_dropout=attention_dropout,
dropout=dropout,
bias=bias,
dtype=dtype,
checkpoint=checkpoint,
init_method=init_method,
tensor_parallel=tensor_parallel)
self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
self.mlp = ViTMLP(dim=dim,
mlp_ratio=mlp_ratio,
activation=activation,
dropout=dropout,
dtype=dtype,
bias=bias,
checkpoint=checkpoint,
init_method=init_method,
tensor_parallel=tensor_parallel)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
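# Note on ViTBlock above: it uses the pre-norm layout, i.e. LayerNorm -> self-attention ->
# residual add (with optional DropPath), then LayerNorm -> MLP -> residual add.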
@MODELS.register_module
class VisionTransformer(nn.Module):
def __init__(self,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
num_heads: int = 12,
dim: int = 768,
mlp_ratio: int = 4,
attention_dropout: float = 0.,
dropout: float = 0.1,
drop_path: float = 0.,
activation: Callable = nn.functional.gelu,
representation_size: int = None,
dtype: dtype = None,
bias: bool = True,
checkpoint: bool = False,
init_method: str = 'torch',
tensor_parallel: str = None):
super().__init__()
embed = ViTEmbedding(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embedding_dim=dim,
dropout=dropout,
dtype=dtype,
init_method=init_method,
tensor_parallel=tensor_parallel,
)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = [
ViTBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
attention_dropout=attention_dropout,
dropout=dropout,
drop_path=dpr[i],
activation=activation,
dtype=dtype,
bias=bias,
checkpoint=checkpoint,
init_method=init_method,
tensor_parallel=tensor_parallel,
) for i in range(depth)
]
norm = col_nn.LayerNorm(
normalized_shape=dim,
eps=1e-6,
dtype=dtype,
tensor_parallel=tensor_parallel,
)
head = ViTHead(
dim=dim,
num_classes=num_classes,
representation_size=representation_size,
dtype=dtype,
bias=bias,
init_method=init_method,
tensor_parallel=tensor_parallel,
)
self.layers = nn.Sequential(
embed,
*blocks,
norm,
head,
)
def forward(self, x):
x = self.layers(x)
return x
def _create_vit_model(**model_kwargs):
model = VisionTransformer(**model_kwargs)
return model
@MODELS.register_module
def vit_lite_depth7_patch4_32(**kwargs):
model_kwargs = dict(img_size=32, patch_size=4, dim=256, depth=7, num_heads=4, mlp_ratio=2, num_classes=10, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_patch4_32(**kwargs):
model_kwargs = dict(img_size=32, patch_size=4, dim=512, depth=6, num_heads=8, mlp_ratio=1, num_classes=10, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_patch16_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_tiny_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch16_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch32_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_small_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch16_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch32_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_base_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch16_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch16_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch32_224(**kwargs):
model_kwargs = dict(img_size=224, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
@MODELS.register_module
def vit_large_patch32_384(**kwargs):
model_kwargs = dict(img_size=384, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
return _create_vit_model(**model_kwargs)
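# Hypothetical usage sketch (not part of the file above), mirroring the pipeline test
# further below: a registered factory can be called directly with a tensor parallel
# mode once a Colossal-AI context has been launched.
from model_zoo.vit import vit_tiny_patch4_32

model = vit_tiny_patch4_32(tensor_parallel='1d')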

View File

@ -0,0 +1,74 @@
import time
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import get_current_device
CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))
SIZE = 8
def check_all_gather():
tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
tensor = tensor.to(get_current_device())
print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
op.wait()
print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
torch.cuda.synchronize()
def check_reduce_scatter():
tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
tensor = tensor.to(get_current_device())
print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
op.wait()
print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
torch.cuda.synchronize()
def check_all_reduce():
tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
tensor = tensor.to(get_current_device())
print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
op.wait()
print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
torch.cuda.synchronize()
def check_layer(rank, world_size):
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30010, backend='nccl')
assert dist.get_rank() == gpc.get_global_rank()
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))
check_all_gather()
check_reduce_scatter()
check_all_reduce()
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
def test_comm():
world_size = 4
run_func = partial(check_layer, world_size=world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_comm()
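# Worked illustration (hypothetical, not part of the test): with world_size = 4 and
# SIZE = 8, rank r starts with [r*8 + 0, ..., r*8 + 7]; assuming all_gather concatenates
# along dim 0, every rank should end up holding the 32-element tensor [0, 1, ..., 31].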

View File

@ -1,141 +0,0 @@
import pytest
from pathlib import Path
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.context.parallel_mode import ParallelMode
from colossalai.logging import get_dist_logger
import colossalai
import torch
import os
from colossalai.builder import build_pipeline_model_from_cfg
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader, MultiTimer
from colossalai.nn.loss import CrossEntropyLoss2D
from colossalai.trainer.metric import Accuracy2D
from colossalai.trainer import metric, hooks, Trainer
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from colossalai.engine.schedule import PipelineSchedule
from torchvision import transforms
from torchvision.datasets import CIFAR10
from colossalai.nn import LinearWarmupLR
from tqdm import tqdm
import vit_t_2d
BATCH_SIZE = 16
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(
parallel=dict(
pipeline=2,
tensor=dict(size=4, mode='2d')
),
fp16=dict(
mode=AMP_TYPE.TORCH
),
gradient_accumulation=2
)
@pytest.mark.dist
@pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
def test_hybrid_parallel():
parser = colossalai.get_default_parser()
args = parser.parse_args()
colossalai.launch_from_slurm(config=CONFIG,
host=args.host,
port=29500)
logger = get_dist_logger()
# if gpc.get_global_rank() == 0:
# logger.log_to_file('./logs/cifar10_2d_vit',
# suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
# build vit-t-32
model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1)
# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.RandomCrop(size=32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
0.2023, 0.1994, 0.2010]),
]
)
)
test_dataset = CIFAR10(
root=Path(os.environ['DATA']),
train=False,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
0.2023, 0.1994, 0.2010]),
]
)
)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
add_sampler=True,
batch_size=BATCH_SIZE,
num_workers=1,
pin_memory=True,
)
test_dataloader = get_dataloader(dataset=test_dataset,
add_sampler=True,
batch_size=BATCH_SIZE,
num_workers=1,
pin_memory=True,
)
# build criterion
criterion = CrossEntropyLoss2D()
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
# lr_scheduler
steps_per_epoch = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(train_dataloader, accumulate_size=2)
total_steps = steps_per_epoch * NUM_EPOCHS
warmup_steps = steps_per_epoch * WARMUP_EPOCHS
lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
model, optimizer, criterion, train_dataloader, test_dataloader, lr_scheduler)
timer = MultiTimer()
schedule = PipelineSchedule(num_microbatches=4)
trainer = Trainer(
engine=engine,
timer=timer,
logger=logger,
schedule=schedule
)
hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
hooks.Accuracy2DHook(),
hooks.LogMetricByEpochHook(logger),
]
trainer.fit(
train_dataloader=train_dataloader,
epochs=NUM_EPOCHS,
test_dataloader=test_dataloader,
test_interval=1,
hooks=hook_list,
display_progress=True
)
if __name__ == '__main__':
test_hybrid_parallel()

View File

@ -1,3 +0,0 @@
#!/usr/bin/env sh
python run_cifar10_vit2d_with_pipeline.py --host $HOST

View File

@ -0,0 +1,103 @@
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.builder import build_pipeline_model
from colossalai.engine.schedule import PipelineSchedule
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, LinearWarmupLR
from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10
BATCH_SIZE = 16
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
fp16=dict(mode=AMP_TYPE.TORCH),
gradient_accumulation=2)
def run_trainer(rank, world_size):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30000, backend='nccl')
logger = get_dist_logger()
model = vit_tiny_patch4_32(tensor_parallel='1d')
pipe_model = build_pipeline_model(model.layers, num_chunks=1)
# build dataloaders
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_train)
test_dataset = CIFAR10(root=Path(os.environ['DATA']), train=False, transform=transform_test)
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True)
# build criterion
criterion = CrossEntropyLoss(tensor_parallel='1d')
# optimizer
optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0)
# lr_scheduler
steps_per_epoch = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(train_dataloader, accumulate_size=2)
total_steps = steps_per_epoch * NUM_EPOCHS
warmup_steps = steps_per_epoch * WARMUP_EPOCHS
lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(pipe_model, optimizer, criterion,
train_dataloader, test_dataloader,
lr_scheduler)
timer = MultiTimer()
schedule = PipelineSchedule(num_microbatches=4)
trainer = Trainer(engine=engine, timer=timer, logger=logger, schedule=schedule)
hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
hooks.AccuracyHook(accuracy_func=Accuracy(tensor_parallel='1d')),
hooks.LogMetricByEpochHook(logger),
]
trainer.fit(train_dataloader=train_dataloader,
epochs=NUM_EPOCHS,
max_steps=5,
test_dataloader=test_dataloader,
test_interval=1,
hooks=hook_list,
display_progress=True)
@pytest.mark.dist
# @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
def test_hybrid_parallel():
world_size = 8
run_func = partial(run_trainer, world_size=world_size)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_hybrid_parallel()
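# Sanity check of the process layout implied by CONFIG and world_size above
# (hypothetical breakdown, assuming the data parallel size is whatever remains
# after pipeline and tensor parallelism are accounted for):
world_size = 8
pipeline_size, tensor_size = 2, 2
data_parallel_size = world_size // (pipeline_size * tensor_size)   # -> 2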

View File

@ -1,74 +0,0 @@
import sys
from pathlib import Path
repo_path = str(Path(__file__).absolute().parents[2])
sys.path.append(repo_path)
try:
import model_zoo.vit.vision_transformer_from_config
except ImportError:
raise ImportError("model_zoo is not found, please check your path")
IMG_SIZE = 32
PATCH_SIZE = 4
DIM = 512
NUM_ATTENTION_HEADS = 8
NUM_CLASSES = 10
DEPTH = 6
model_cfg = dict(
type='VisionTransformerFromConfig',
tensor_splitting_cfg=dict(
type='ViTInputSplitter2D',
),
embedding_cfg=dict(
type='ViTPatchEmbedding2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
),
token_fusion_cfg=dict(
type='ViTTokenFuser2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
drop_rate=0.1
),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
block_cfg=dict(
type='ViTBlock',
attention_cfg=dict(
type='ViTSelfAttention2D',
hidden_size=DIM,
num_attention_heads=NUM_ATTENTION_HEADS,
attention_dropout_prob=0.,
hidden_dropout_prob=0.1,
),
droppath_cfg=dict(
type='VanillaViTDropPath',
),
mlp_cfg=dict(
type='ViTMLP2D',
in_features=DIM,
dropout_prob=0.1,
mlp_ratio=1
),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
),
head_cfg=dict(
type='ViTHead2D',
hidden_size=DIM,
num_classes=NUM_CLASSES,
),
embed_dim=DIM,
depth=DEPTH,
drop_path_rate=0.,
)

View File

@ -1,40 +0,0 @@
import os
from pathlib import Path
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
# resnet 18
model = dict(type='VanillaResNet',
block_type='ResNetBasicBlock',
layers=[2, 2, 2, 2],
num_cls=10)
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
train_data = dict(dataset=dict(type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
download=True,
transform_pipeline=[
dict(type='Resize',
size=(IMG_SIZE, IMG_SIZE)),
dict(type='ToTensor'),
dict(type='Normalize',
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5))
]),
dataloader=dict(batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=4,
drop_last=True))
optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')

View File

@ -1,16 +0,0 @@
import os
from pathlib import Path
from colossalai.engine import AMP_TYPE
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
fp16 = dict(mode=AMP_TYPE.APEX)

View File

@ -1,42 +0,0 @@
import os
from pathlib import Path
from colossalai.engine import AMP_TYPE
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
# resnet 18
model = dict(type='VanillaResNet',
block_type='ResNetBasicBlock',
layers=[2, 2, 2, 2],
num_cls=10)
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
train_data = dict(dataset=dict(type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
download=True,
transform_pipeline=[
dict(type='Resize',
size=(IMG_SIZE, IMG_SIZE)),
dict(type='ToTensor'),
dict(type='Normalize',
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5))
]),
dataloader=dict(batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=4,
drop_last=True))
optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
fp16 = dict(mode=AMP_TYPE.TORCH)

View File

@ -1,46 +0,0 @@
import os
from pathlib import Path
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12
# resnet 18
model = dict(type='VanillaResNet',
block_type='ResNetBasicBlock',
layers=[2, 2, 2, 2],
num_cls=10)
train_data = dict(dataset=dict(type='CIFAR10Dataset',
root=Path(os.environ['DATA']),
download=True,
transform_pipeline=[
dict(type='Resize',
size=(IMG_SIZE, IMG_SIZE)),
dict(type='ToTensor'),
dict(type='Normalize',
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5))
]),
dataloader=dict(batch_size=BATCH_SIZE,
pin_memory=True,
num_workers=4,
drop_last=True))
optimizer = dict(type='Adam', lr=0.001)
loss = dict(type='CrossEntropyLoss')
parallel = dict(
pipeline=dict(size=4),
tensor=dict(size=1, mode=None)
)
engine = dict(
schedule=dict(
num_microbatches=4
)
)
num_epochs = 10

View File

@ -4,7 +4,7 @@ from torch.nn import Parameter
import time
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import Linear1D_Col, Linear1D_Row, TransformerMLP1D, TransformerSelfAttention1D, ViTMLP1D, ViTSelfAttention1D, ViTPatchEmbedding1D, ViTHead1D, ViTTokenFuser1D
from colossalai.nn import Linear1D_Col, Linear1D_Row
from colossalai.utils import get_current_device, print_rank_0
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
@ -17,7 +17,7 @@ def check_linear_col():
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE, gather_output=True)
layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
@ -50,18 +50,20 @@ def check_linear_col():
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = C_master.clone()
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('linear_col gather_output forward: pass')
print_rank_0('linear_col forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
dist.broadcast(grad_master, src=0)
grad = grad_master.detach()
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
C_master.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
@ -73,7 +75,7 @@ def check_linear_col():
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_col gather_output backward: pass')
print_rank_0('linear_col backward: pass')
def check_linear_row():
@ -84,12 +86,13 @@ def check_linear_row():
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, parallel_input=False)
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
@ -119,16 +122,18 @@ def check_linear_row():
C = C_master.clone()
check_equal(out, C)
print_rank_0('linear_row no parallel_input forward: pass')
print_rank_0('linear_row forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
dist.broadcast(grad_master, src=0)
grad = grad_master.detach()
grad = grad_master.clone()
out.backward(grad)
C_master.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
@ -138,276 +143,4 @@ def check_linear_row():
B_grad = B_master.grad
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_row no parallel_input backward: pass')
class Testvithead(torch.nn.Module):
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
def forward(self, x):
x = x[:, 0]
x = self.linear(x)
return x
def check_head():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
head = ViTHead1D(INPUT_SIZE, NUM_CLASSES, dtype=dtype)
torch.nn.init.zeros_(head.linear.bias)
torch.nn.init.ones_(head.linear.weight)
head = head.to(device)
layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
torch.nn.init.zeros_(layer.linear.bias)
torch.nn.init.ones_(layer.linear.weight)
layer = layer.to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
fwd_start = time.time()
out = head(A)
fwd_end = time.time()
print_rank_0(
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer(A_master)
# C = torch.chunk(C_master, DEPTH, dim=0)[i]
print_rank_0('Rank {} head forward: {}'.format(i, check_equal(out, C_master)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape,
dtype=dtype,
device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
# grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
# bwd_start = time.time()
out.backward(grad_master)
# bwd_end = time.time()
# print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
# logger)
C_master.backward(grad_master)
A_grad = A_master.grad
# if j == 0:
print_rank_0('Rank {} head backward (input_grad): {}'.format(
i, check_equal(A_grad, A.grad)))
class Testvitembed(torch.nn.Module):
def __init__(self, img_size: int, patch_size: int, in_chans: int,
embed_size: int, drop_prob: float) -> None:
super().__init__()
self.proj = torch.nn.Conv2d(in_chans,
embed_size,
kernel_size=patch_size,
stride=patch_size)
num_patches = (img_size // patch_size)**2
self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size))
self.pos_embed = torch.nn.Parameter(
torch.zeros(1, num_patches + 1, embed_size))
self.pos_drop = torch.nn.Dropout(drop_prob)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x
def check_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = ViTPatchEmbedding1D(IMG_SIZE, 4, HIDDEN_SIZE)
layer2 = ViTTokenFuser1D(IMG_SIZE, 4, HIDDEN_SIZE)
torch.nn.init.zeros_(layer.proj.bias)
torch.nn.init.ones_(layer.proj.weight)
torch.nn.init.ones_(layer2.cls_token)
torch.nn.init.ones_(layer2.pos_embed)
layer = layer.to(device)
layer2 = layer2.to(device)
layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.)
torch.nn.init.zeros_(layer_master.proj.bias)
torch.nn.init.ones_(layer_master.proj.weight)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer2(layer(A))
fwd_end = time.time()
print_rank_0(
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
# out_cls = out[:, 0]
# out_tensor = out[:, 1:]
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
# if j == 0:
# C_cls = C_master[:, 0]
# C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i]
# C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k]
# logger.info('Rank {} embed forward (cls): {}'.format(
# rank, check_equal(out_cls, C_cls)))
# C = C_master[:, 1:]
print_rank_0('Rank {} embed forward: {}'.format(i, check_equal(out, C_master)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape,
dtype=dtype,
device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
# cls_grad = grad_master[:, 0]
# cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i]
# cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k]
# grad = grad_master[:, 1:]
# grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1)
bwd_start = time.time()
out.backward(grad_master)
bwd_end = time.time()
print_rank_0(
'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start))
C_master.backward(grad_master)
A_grad = A_master.grad
print_rank_0('Rank {} embed backward (input_grad): {}'.format(i, check_equal(A_grad, A.grad)))
print_rank_0('Rank {} embed backward (cls_grad): {}'.format(
i, check_equal(layer_master.cls_token.grad, layer2.cls_token.grad)))
print_rank_0('Rank {} embed backward (pos_embed_grad): {}'.format(
i, check_equal(layer_master.pos_embed.grad, layer2.pos_embed.grad)))
print_rank_0('Rank {} embed backward (proj_weight_grad): {}'.format(
i, check_equal(layer_master.proj.weight.grad, layer.proj.weight.grad)))
print_rank_0('Rank {} embed backward (proj_bias_grad): {}'.format(
i, check_equal(layer_master.proj.bias.grad, layer.proj.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_attention():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = ViTSelfAttention1D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
0.5,
0.5
).to(device=device)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A)
assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
print_rank_0('self attention forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('self attention backward: pass')
def check_mlp():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = ViTMLP1D(
HIDDEN_SIZE,
4.0
).to(device=device)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
out = layer(A)
assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
print_rank_0('mlp forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('mlp backward: pass')
def check_patch_embedding():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = 4
PATCH_SIZE = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = ViTPatchEmbedding1D(
INPUT_SIZE,
PATCH_SIZE,
HIDDEN_SIZE,
).to(device=device)
A_shape = (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
out = layer(A)
print('output size: ', out.size())
assert out.shape == (BATCH_SIZE, 4, HIDDEN_SIZE)
print_rank_0('patch embedding forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('patch embedding backward: pass')
print_rank_0('linear_row backward: pass')
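# Shape sketch for the updated 1D checks above (hypothetical summary of what the test
# asserts, with DEPTH parallel ranks as defined in the common constants below):
# Linear1D_Col(in_features, out_features): the output is split along the last dim, so
#     each rank produces out_features / DEPTH columns (compared against
#     torch.chunk(C_master, DEPTH, dim=-1)[rank]).
# Linear1D_Row(in_features, out_features): the input is split along the last dim, so
#     each rank consumes in_features / DEPTH columns and the full-size output is
#     recovered (compared against the unchunked C_master).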

View File

@ -3,12 +3,12 @@
import torch
DEPTH = 2
DEPTH = 4
BATCH_SIZE = 8
SEQ_LENGTH = 8
IMG_SIZE = 16
HIDDEN_SIZE = 8
NUM_CLASSES = 10
NUM_CLASSES = 8
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True

Some files were not shown because too many files have changed in this diff