Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-25 23:11:55 +00:00)

Layer integration (#83)

* integrated parallel layers for ease of building models
* integrated 2.5d layers
* cleaned codes and unit tests
* added log metric by step hook; updated imagenet benchmark; fixed some bugs
* reworked initialization; cleaned codes

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>

This commit is contained in:
parent 5c3843dc98
commit 0fedef4f3c
benchmark/README.md (new file, 66 lines)
@@ -0,0 +1,66 @@
# Benchmark for Tuning Accuracy and Efficiency

## Overview

The benchmark includes our efforts in using Colossal-AI to train different tasks to achieve SOTA results.
We are interested in both validation accuracy and training speed, and prefer larger batch sizes to take advantage of more GPU devices.
For example, we trained the vision transformer with batch size 512 on CIFAR10 and 4096 on ImageNet1k, which are rarely used in existing works.
Some of the results in the benchmark, trained with 8x A100 GPUs, are shown below.

| Task | Model | Training Time | Top-1 Accuracy |
| ---------- | ------------ | ------------- | -------------- |
| CIFAR10 | [ViT-Lite-7/4](https://arxiv.org/pdf/2104.05704.pdf) | ~ 16 min | ~ 90.5% |
| ImageNet1k | ViT-S/16 | ~ 16.5 h | ~ 74.5% |

The `train.py` script in each task runs training with the specific configuration script in `configs/` for different parallelisms.
Supported parallelisms include data parallel only (ends with `vanilla`), 1D (ends with `1d`), 2D (ends with `2d`), 2.5D (ends with `2p5d`), and 3D (ends with `3d`).

Each configuration script includes the following elements, taking the ImageNet1k task as an example:
```
TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

# data parallel only
TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

# parallelism setting
parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )  # amp setting

gradient_accumulation = 2  # accumulate 2 steps for gradient update

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation  # actual batch size for dataloader

clip_grad_norm = 1.0  # clip gradient with norm 1.0
```
Upper-case elements are what `train.py` needs, and lower-case elements are what Colossal-AI needs to initialize the training.
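As a minimal sketch of how the two kinds of elements are consumed (distilled from the `train.py` scripts added in this commit; the final print is only illustrative), the upper-case values are read back through the global context after launching:
```
import colossalai
from colossalai.core import global_context as gpc

# parse the standard launch arguments (--config, ...) and let Colossal-AI
# load the configuration file into gpc.config
args = colossalai.get_default_parser().parse_args()
colossalai.launch_from_torch(config=args.config)

# upper-case entries are read back explicitly by the training script
tp_mode = gpc.config.parallel.tensor.mode
lr = gpc.config.LEARNING_RATE
num_epochs = gpc.config.NUM_EPOCHS

# lower-case entries (parallel, fp16, gradient_accumulation, clip_grad_norm)
# are consumed by colossalai.initialize() when the engine is built
print(f'tensor parallel: {tp_mode}, lr: {lr}, epochs: {num_epochs}')
```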
## Usage

To start training, use the following command to run each worker:
```
$ DATA=/path/to/dataset python train.py --world_size=WORLD_SIZE \
                                        --rank=RANK \
                                        --local_rank=LOCAL_RANK \
                                        --host=MASTER_IP_ADDRESS \
                                        --port=MASTER_PORT \
                                        --config=CONFIG_FILE
```
It is also recommended to start training with `torchrun` as:
```
$ DATA=/path/to/dataset torchrun --nproc_per_node=NUM_GPUS_PER_NODE \
                                 --nnodes=NUM_NODES \
                                 --node_rank=NODE_RANK \
                                 --master_addr=MASTER_IP_ADDRESS \
                                 --master_port=MASTER_PORT \
                                 train.py --config=CONFIG_FILE
```
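For instance (the dataset path, port, and GPU count below are placeholders for illustration), a single-node run of the CIFAR10 benchmark on 8 GPUs with the 2D tensor-parallel configuration could be launched from the `benchmark/cifar/` directory as:
```
$ DATA=/path/to/cifar10 torchrun --nproc_per_node=8 \
                                 --nnodes=1 \
                                 --node_rank=0 \
                                 --master_addr=localhost \
                                 --master_port=29500 \
                                 train.py --config=configs/vit_2d.py
```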
benchmark/cifar/configs/vit_1d.py (new file, 18 lines)
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
benchmark/cifar/configs/vit_2d.py (new file, 18 lines)
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
benchmark/cifar/configs/vit_2p5d.py (new file, 19 lines)
@@ -0,0 +1,19 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
benchmark/cifar/configs/vit_3d.py (new file, 18 lines)
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
benchmark/cifar/configs/vit_vanilla.py (new file, 18 lines)
@@ -0,0 +1,18 @@
BATCH_SIZE = 512
LEARNING_RATE = 2e-3
WEIGHT_DECAY = 3e-2

TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

NUM_EPOCHS = 200
WARMUP_EPOCHS = 40

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

seed = 42

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_cifar10_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}/"
benchmark/cifar/train.py (new file, 126 lines)
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os

import colossalai
import torch
import torchvision
from colossalai.builder import *
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook,
                                      LogMetricByEpochHook,
                                      LogMetricByStepHook,
                                      LogTimingByEpochHook, LossHook,
                                      LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer, get_dataloader
from model_zoo.vit import vit_lite_depth7_patch4_32
from torchvision import transforms

DATASET_PATH = str(os.environ['DATA'])


def build_cifar(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH,
                                                 train=True,
                                                 download=True,
                                                 transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=batch_size,
                                      num_workers=4,
                                      pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)
    return train_dataloader, test_dataloader


def train_cifar():
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    tp = gpc.config.parallel.tensor.mode

    model = vit_lite_depth7_patch4_32(tensor_parallel=tp)

    train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)

    optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    steps_per_epoch = len(train_dataloader)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS * steps_per_epoch,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS * steps_per_epoch)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model=model,
                                                                                    optimizer=optimizer,
                                                                                    criterion=criterion,
                                                                                    train_dataloader=train_dataloader,
                                                                                    test_dataloader=test_dataloader,
                                                                                    lr_scheduler=lr_scheduler)

    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hooks = [
        LogMetricByEpochHook(logger=logger),
        LogMetricByStepHook(),
        # LogTimingByEpochHook(timer=timer, logger=logger),
        # LogMemoryByEpochHook(logger=logger),
        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
        LossHook(),
        ThroughputHook(),
        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train_cifar()
benchmark/imagenet100/configs/vit_1d.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet100/configs/vit_2d.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet100/configs/vit_2p5d.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet100/configs/vit_3d.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet100/configs/vit_vanilla.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet100_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet100/train.py (new file, 211 lines)
@@ -0,0 +1,211 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import glob
import os

import colossalai
import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec
import torch
from colossalai.builder import *
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
                                      LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

DATASET_PATH = str(os.environ['DATA'])

TRAIN_RECS = DATASET_PATH + '/train/*'
VAL_RECS = DATASET_PATH + '/validation/*'
TRAIN_IDX = DATASET_PATH + '/idx_files/train/*'
VAL_IDX = DATASET_PATH + '/idx_files/validation/*'


class DaliDataloader(DALIClassificationIterator):
    def __init__(self,
                 tfrec_filenames,
                 tfrec_idx_filenames,
                 shard_id=0,
                 num_shards=1,
                 batch_size=128,
                 num_threads=4,
                 resize=256,
                 crop=224,
                 prefetch=2,
                 training=True,
                 gpu_aug=False,
                 cuda=True):
        pipe = Pipeline(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=torch.cuda.current_device() if cuda else None,
                        seed=1024)
        with pipe:
            inputs = fn.readers.tfrecord(path=tfrec_filenames,
                                         index_path=tfrec_idx_filenames,
                                         random_shuffle=training,
                                         shard_id=shard_id,
                                         num_shards=num_shards,
                                         initial_fill=10000,
                                         read_ahead=True,
                                         prefetch_queue_depth=prefetch,
                                         name='Reader',
                                         features={
                                             'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                                             'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                                         })
            images = inputs["image/encoded"]

            if training:
                images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
                images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu')
                flip_lr = fn.random.coin_flip(probability=0.5)
            else:
                # decode jpeg and resize
                images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
                images = fn.resize(images,
                                   device='gpu' if gpu_aug else 'cpu',
                                   resize_x=resize,
                                   resize_y=resize,
                                   dtype=types.FLOAT,
                                   interp_type=types.INTERP_TRIANGULAR)
                flip_lr = False

            # center crop and normalise
            images = fn.crop_mirror_normalize(images,
                                              dtype=types.FLOAT,
                                              crop=(crop, crop),
                                              mean=[127.5],
                                              std=[127.5],
                                              mirror=flip_lr)
            label = inputs["image/class/label"] - 1  # 0-999
            # LSG: element_extract will raise exception, let's flatten outside
            # label = fn.element_extract(label, element_map=0)  # Flatten
            if cuda:  # transfer data to gpu
                pipe.set_outputs(images.gpu(), label.gpu())
            else:
                pipe.set_outputs(images, label)

        pipe.build()
        last_batch_policy = 'DROP' if training else 'PARTIAL'
        super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)

    def __iter__(self):
        # if not reset (after an epoch), reset; if just initialize, ignore
        if self._counter >= self._size or self._size < 0:
            self.reset()
        return self

    def __next__(self):
        data = super().__next__()
        img, label = data[0]['data'], data[0]['label']
        label = label.squeeze()
        return (img, ), (label, )


def build_dali_train(batch_size):
    return DaliDataloader(
        sorted(glob.glob(TRAIN_RECS)),
        sorted(glob.glob(TRAIN_IDX)),
        batch_size=batch_size,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=True,
        cuda=True,
    )


def build_dali_test(batch_size):
    return DaliDataloader(
        sorted(glob.glob(VAL_RECS)),
        sorted(glob.glob(VAL_IDX)),
        batch_size=batch_size,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        gpu_aug=True,
        cuda=True,
    )


def train_imagenet():
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    tp = gpc.config.parallel.tensor.mode

    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=100, init_method='jax')

    train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
    test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)

    optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
                                                                         optimizer=optimizer,
                                                                         criterion=criterion,
                                                                         train_dataloader=train_dataloader,
                                                                         test_dataloader=test_dataloader)

    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hooks = [
        LogMetricByEpochHook(logger=logger),
        LogMetricByStepHook(),
        # LogTimingByEpochHook(timer=timer, logger=logger),
        # LogMemoryByEpochHook(logger=logger),
        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
        LossHook(),
        ThroughputHook(),
        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train_imagenet()
benchmark/imagenet1k/configs/vit_1d.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '1d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet1k/configs/vit_2d.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
TENSOR_PARALLEL_MODE = '2d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet1k/configs/vit_2p5d.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 4
DEPTH = 1
TENSOR_PARALLEL_MODE = '2.5d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE, depth=DEPTH),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet1k/configs/vit_3d.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 8
TENSOR_PARALLEL_MODE = '3d'

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet1k/configs/vit_vanilla.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from colossalai.amp import AMP_TYPE

TOTAL_BATCH_SIZE = 4096
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3

TENSOR_PARALLEL_SIZE = 1
TENSOR_PARALLEL_MODE = None

NUM_EPOCHS = 300
WARMUP_EPOCHS = 32

parallel = dict(
    pipeline=1,
    tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
)

fp16 = dict(mode=AMP_TYPE.TORCH, )

gradient_accumulation = 2

BATCH_SIZE = TOTAL_BATCH_SIZE // gradient_accumulation

clip_grad_norm = 1.0

LOG_PATH = f"./vit_{TENSOR_PARALLEL_MODE}_imagenet1k_tp{TENSOR_PARALLEL_SIZE}_bs{BATCH_SIZE}_lr{LEARNING_RATE}_{fp16['mode']}_clip_grad{clip_grad_norm}/"
benchmark/imagenet1k/train.py (new file, 211 lines)
@@ -0,0 +1,211 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import glob
import os

import colossalai
import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec
import torch
from colossalai.builder import *
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import (AccuracyHook, LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook,
                                      LogTimingByEpochHook, LossHook, LRSchedulerHook, ThroughputHook)
from colossalai.utils import MultiTimer
from model_zoo.vit import vit_small_patch16_224
from nvidia.dali import types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

DATASET_PATH = str(os.environ['DATA'])

TRAIN_RECS = DATASET_PATH + '/train/*'
VAL_RECS = DATASET_PATH + '/validation/*'
TRAIN_IDX = DATASET_PATH + '/idx_files/train/*'
VAL_IDX = DATASET_PATH + '/idx_files/validation/*'


class DaliDataloader(DALIClassificationIterator):
    def __init__(self,
                 tfrec_filenames,
                 tfrec_idx_filenames,
                 shard_id=0,
                 num_shards=1,
                 batch_size=128,
                 num_threads=4,
                 resize=256,
                 crop=224,
                 prefetch=2,
                 training=True,
                 gpu_aug=False,
                 cuda=True):
        pipe = Pipeline(batch_size=batch_size,
                        num_threads=num_threads,
                        device_id=torch.cuda.current_device() if cuda else None,
                        seed=1024)
        with pipe:
            inputs = fn.readers.tfrecord(path=tfrec_filenames,
                                         index_path=tfrec_idx_filenames,
                                         random_shuffle=training,
                                         shard_id=shard_id,
                                         num_shards=num_shards,
                                         initial_fill=10000,
                                         read_ahead=True,
                                         prefetch_queue_depth=prefetch,
                                         name='Reader',
                                         features={
                                             'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
                                             'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
                                         })
            images = inputs["image/encoded"]

            if training:
                images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
                images = fn.random_resized_crop(images, size=crop, device='gpu' if gpu_aug else 'cpu')
                flip_lr = fn.random.coin_flip(probability=0.5)
            else:
                # decode jpeg and resize
                images = fn.decoders.image(images, device='mixed' if gpu_aug else 'cpu', output_type=types.RGB)
                images = fn.resize(images,
                                   device='gpu' if gpu_aug else 'cpu',
                                   resize_x=resize,
                                   resize_y=resize,
                                   dtype=types.FLOAT,
                                   interp_type=types.INTERP_TRIANGULAR)
                flip_lr = False

            # center crop and normalise
            images = fn.crop_mirror_normalize(images,
                                              dtype=types.FLOAT,
                                              crop=(crop, crop),
                                              mean=[127.5],
                                              std=[127.5],
                                              mirror=flip_lr)
            label = inputs["image/class/label"] - 1  # 0-999
            # LSG: element_extract will raise exception, let's flatten outside
            # label = fn.element_extract(label, element_map=0)  # Flatten
            if cuda:  # transfer data to gpu
                pipe.set_outputs(images.gpu(), label.gpu())
            else:
                pipe.set_outputs(images, label)

        pipe.build()
        last_batch_policy = 'DROP' if training else 'PARTIAL'
        super().__init__(pipe, reader_name="Reader", auto_reset=True, last_batch_policy=last_batch_policy)

    def __iter__(self):
        # if not reset (after an epoch), reset; if just initialize, ignore
        if self._counter >= self._size or self._size < 0:
            self.reset()
        return self

    def __next__(self):
        data = super().__next__()
        img, label = data[0]['data'], data[0]['label']
        label = label.squeeze()
        return (img, ), (label, )


def build_dali_train(batch_size):
    return DaliDataloader(
        sorted(glob.glob(TRAIN_RECS)),
        sorted(glob.glob(TRAIN_IDX)),
        batch_size=batch_size,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=True,
        gpu_aug=True,
        cuda=True,
    )


def build_dali_test(batch_size):
    return DaliDataloader(
        sorted(glob.glob(VAL_RECS)),
        sorted(glob.glob(VAL_IDX)),
        batch_size=batch_size,
        shard_id=gpc.get_local_rank(ParallelMode.DATA),
        num_shards=gpc.get_world_size(ParallelMode.DATA),
        training=False,
        gpu_aug=True,
        cuda=True,
    )


def train_imagenet():
    args = colossalai.get_default_parser().parse_args()
    # standard launch
    # colossalai.launch(config=args.config,
    #                   rank=args.rank,
    #                   world_size=args.world_size,
    #                   local_rank=args.local_rank,
    #                   host=args.host,
    #                   port=args.port)

    # launch from torchrun
    colossalai.launch_from_torch(config=args.config)

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    tp = gpc.config.parallel.tensor.mode

    model = vit_small_patch16_224(tensor_parallel=tp, num_classes=1000, init_method='jax')

    train_dataloader = build_dali_train(gpc.config.BATCH_SIZE // gpc.data_parallel_size)
    test_dataloader = build_dali_test(gpc.config.BATCH_SIZE // gpc.data_parallel_size)

    criterion = CrossEntropyLoss(label_smoothing=0.1, tensor_parallel=tp)

    optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
                                                                         optimizer=optimizer,
                                                                         criterion=criterion,
                                                                         train_dataloader=train_dataloader,
                                                                         test_dataloader=test_dataloader)

    logger.info("Engine is built", ranks=[0])

    timer = MultiTimer()

    trainer = Trainer(engine=engine, logger=logger, timer=timer)
    logger.info("Trainer is built", ranks=[0])

    hooks = [
        LogMetricByEpochHook(logger=logger),
        LogMetricByStepHook(),
        # LogTimingByEpochHook(timer=timer, logger=logger),
        # LogMemoryByEpochHook(logger=logger),
        AccuracyHook(accuracy_func=Accuracy(tensor_parallel=tp)),
        LossHook(),
        ThroughputHook(),
        LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True)
    ]

    logger.info("Train start", ranks=[0])
    trainer.fit(train_dataloader=train_dataloader,
                test_dataloader=test_dataloader,
                epochs=gpc.config.NUM_EPOCHS,
                hooks=hooks,
                display_progress=True,
                test_interval=1)


if __name__ == '__main__':
    train_imagenet()
@@ -1,14 +1,17 @@
from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                  send_backward, send_backward_recv_backward, send_forward_recv_backward,
                  send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce
from .p2p import (send_forward, send_forward_recv_forward,
                  send_backward_recv_forward, send_backward,
                  send_backward_recv_backward, send_forward_recv_backward,
                  send_forward_backward_recv_forward_backward, recv_forward,
                  recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta

__all__ = [
    'all_gather', 'reduce_scatter', 'all_reduce',
    'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
    'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
    'all_gather', 'reduce_scatter', 'all_reduce', 'broadcast', 'reduce',
    'send_forward', 'send_forward_recv_forward',
    'send_forward_backward_recv_forward_backward', 'send_backward',
    'send_backward_recv_backward', 'send_backward_recv_forward',
    'send_forward_recv_backward', 'recv_backward', 'recv_forward',
    'ring_forward', 'send_tensor_meta', 'recv_tensor_meta'
]
@ -3,6 +3,7 @@
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
@ -10,8 +11,7 @@ from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, dim: int,
|
||||
parallel_mode: ParallelMode, async_op=False) -> Tensor:
|
||||
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
|
||||
"""Gathers all tensors from the parallel group and concatenates them in a
|
||||
specific dimension.
|
||||
|
||||
@ -25,29 +25,31 @@ def all_gather(tensor: Tensor, dim: int,
|
||||
:rtype: :class:`torch.Tensor`
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
temp = tensor.clone()
|
||||
# shape = list(temp.shape)
|
||||
# shape[dim] *= depth
|
||||
# out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
|
||||
# out = list(torch.chunk(out, depth, dim=dim))
|
||||
# out = [val.contiguous() for val in out]
|
||||
shape = [1] * len(tensor.shape)
|
||||
shape[dim] = depth
|
||||
out = tensor.repeat(shape)
|
||||
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
|
||||
op = dist.all_gather(tensor_list=out,
|
||||
tensor=temp,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
# out = torch.cat(out, dim=dim)
|
||||
if depth == 1:
|
||||
out = [tensor]
|
||||
work = None
|
||||
else:
|
||||
shape = list(tensor.shape)
|
||||
shape[0], shape[dim] = shape[dim], shape[0]
|
||||
shape[0] *= depth
|
||||
out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
|
||||
temp = list(torch.chunk(out, depth, dim=0))
|
||||
work = dist.all_gather(tensor_list=temp,
|
||||
tensor=tensor.transpose(0, dim).contiguous(),
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
out = torch.transpose(out, 0, dim)
|
||||
if async_op:
|
||||
return out, op
|
||||
return out, work
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def reduce_scatter(tensor: Tensor, dim: int,
|
||||
parallel_mode: ParallelMode, async_op=False) -> Tensor:
|
||||
def reduce_scatter(tensor: Tensor,
|
||||
dim: int,
|
||||
parallel_mode: ParallelMode,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
async_op: bool = False) -> Tensor:
|
||||
"""Reduces all tensors then scatters it in a specific dimension to all
|
||||
members in the parallel group.
|
||||
|
||||
@ -61,52 +63,57 @@ def reduce_scatter(tensor: Tensor, dim: int,
|
||||
:rtype: :class:`Tensor`
|
||||
"""
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
# temp = list(torch.chunk(tensor, depth, dim=dim))
|
||||
# temp = [val.contiguous() for val in temp]
|
||||
# out = torch.zeros(temp[0].shape,
|
||||
# dtype=temp[0].dtype,
|
||||
# device=get_current_device())
|
||||
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
|
||||
out = temp[0].clone()
|
||||
op = dist.reduce_scatter(output=out,
|
||||
input_list=temp,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
if depth == 1:
|
||||
out = tensor
|
||||
work = None
|
||||
else:
|
||||
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
|
||||
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
|
||||
work = dist.reduce_scatter(output=out,
|
||||
input_list=temp,
|
||||
op=op,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
if async_op:
|
||||
return out, op
|
||||
return out, work
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
def all_reduce(tensor: Tensor,
|
||||
parallel_mode: ParallelMode,
|
||||
async_op=False) -> Tensor:
|
||||
op = dist.all_reduce(tensor,
|
||||
group=gpc.get_group(parallel_mode),
|
||||
async_op=async_op)
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
async_op: bool = False) -> Tensor:
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
if depth == 1:
|
||||
work = None
|
||||
else:
|
||||
work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
|
||||
if async_op:
|
||||
return tensor, op
|
||||
return tensor, work
|
||||
else:
|
||||
return tensor
|
||||
|
||||
|
||||
# def scatter(tensor: Tensor, src: int, dim: int,
|
||||
# parallel_mode: ParallelMode) -> Tensor:
|
||||
# """Scatters in a specific dimension from source rank to all ranks in
|
||||
# the parallel group.
|
||||
|
||||
# :param tensor: Tensor to be scattered
|
||||
# :param dim: The dimension scattering in
|
||||
# :param parallel_mode: Parallel group mode used in this communication
|
||||
# :type tensor: Tensor
|
||||
# :type dim: int
|
||||
# :type parallel_mode: ParallelMode
|
||||
# :return: The tensor generated by scatter
|
||||
# :rtype: Tensor
|
||||
# """
|
||||
# depth = gpc.get_world_size(parallel_mode)
|
||||
# temp = tensor.clone()
|
||||
# dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
|
||||
# rank = gpc.get_local_rank(parallel_mode)
|
||||
# out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
|
||||
# return out
|
||||
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
if depth == 1:
|
||||
work = None
|
||||
else:
|
||||
work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
|
||||
if async_op:
|
||||
return tensor, work
|
||||
else:
|
||||
return tensor
|
||||
|
||||
|
||||
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
|
||||
depth = gpc.get_world_size(parallel_mode)
|
||||
if depth == 1:
|
||||
work = None
|
||||
else:
|
||||
work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
|
||||
if async_op:
|
||||
return tensor, work
|
||||
else:
|
||||
return tensor
|
||||
|
@@ -497,8 +497,7 @@ class ParallelContext:
self._logger.info(
    f"initialized seed on rank {global_rank}, "
    f"numpy: {seed}, python random: {seed}, {seed_str},"
    f"the default parallel seed is {ParallelMode.DATA}.",
    ranks=[0])
    f"the default parallel seed is {ParallelMode.DATA}.")
else:
    if self._verbose:
        self._logger.info(
@@ -184,8 +184,6 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],


def launch_from_torch(config: Union[str, Path, Config, Dict],
                      host: str,
                      port: int,
                      backend: str = 'nccl',
                      seed: int = 1024,
                      verbose: bool = True):
@@ -206,6 +204,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    host = os.environ['MASTER_ADDR']
    port = int(os.environ['MASTER_PORT'])
    launch(config=config,
           local_rank=local_rank,
           rank=rank,
@@ -1,5 +1,6 @@
from .layer import *
from .loss import *
from .lr_scheduler import *
from .metric import *
from .model import *
from .optimizer import *
@ -1,33 +1,140 @@
|
||||
import math
|
||||
import warnings
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import init as init
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_weight_(tensor: Tensor, fan_in: int, fan_out: int = None, init_method: str = 'torch'):
|
||||
if init_method == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
elif init_method == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(tensor, -a, a)
|
||||
elif init_method == 'jax_embed':
|
||||
def zeros_():
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.zeros_(tensor)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def ones_():
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.ones_(tensor)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def uniform_(a: float = 0., b: float = 1.):
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.uniform_(tensor, a, b)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def normal_(mean: float = 0., std: float = 1.):
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.normal_(tensor, mean, std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
return nn.init.trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
if 0 in tensor.shape:
|
||||
warnings.warn("Initializing zero-element tensors is a no-op")
|
||||
return tensor
|
||||
|
||||
if mode == 'fan_in':
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
fan = fan_in
|
||||
elif mode == 'fan_out':
|
||||
assert fan_out is not None, 'Fan_out is not provided.'
|
||||
fan = fan_out
|
||||
else:
|
||||
raise ValueError(f'Invalid initialization mode \'{mode}\'')
|
||||
|
||||
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
|
||||
bound = math.sqrt(3.) * std
|
||||
return nn.init.uniform_(tensor, -bound, bound)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
if 0 in tensor.shape:
|
||||
warnings.warn("Initializing zero-element tensors is a no-op")
|
||||
return tensor
|
||||
|
||||
if mode == 'fan_in':
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
fan = fan_in
|
||||
elif mode == 'fan_out':
|
||||
assert fan_out is not None, 'Fan_out is not provided.'
|
||||
fan = fan_out
|
||||
else:
|
||||
raise ValueError(f'Invalid initialization mode \'{mode}\'')
|
||||
|
||||
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan)
|
||||
return nn.init.normal_(tensor, 0, std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
fan = fan_in
|
||||
if fan_out is not None:
|
||||
fan += fan_out
|
||||
|
||||
std = gain * math.sqrt(scale / float(fan))
|
||||
bound = a * std
|
||||
return nn.init.uniform_(tensor, -bound, bound)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def xavier_normal_(scale: float = 2., gain: float = 1.):
|
||||
# adapted from torch.nn.init
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
fan = fan_in
|
||||
if fan_out is not None:
|
||||
fan += fan_out
|
||||
|
||||
std = gain * math.sqrt(scale / float(fan))
|
||||
|
||||
return nn.init.normal_(tensor, 0., std)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def lecun_uniform_():
|
||||
# adapted from jax.nn.initializers
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
var = 1.0 / fan_in
|
||||
bound = math.sqrt(3 * var)
|
||||
return nn.init.uniform_(tensor, -bound, bound)
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def lecun_normal_():
|
||||
# adapted from jax.nn.initializers
|
||||
def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
|
||||
assert fan_in is not None, 'Fan_in is not provided.'
|
||||
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
init.trunc_normal_(tensor, std=std / .87962566103423978)
|
||||
elif init_method == 'zero':
|
||||
init.zeros_(tensor)
|
||||
return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
|
||||
|
||||
def init_bias_(tensor: Tensor, fan_in: int, init_method: str = 'torch'):
|
||||
if init_method == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(tensor, -bound, bound)
|
||||
elif init_method == 'jax':
|
||||
init.normal_(tensor, std=1e-6)
|
||||
elif init_method == 'jax_embed':
|
||||
init.trunc_normal_(tensor, std=.02)
|
||||
elif init_method == 'zero':
|
||||
init.zeros_(tensor)
|
||||
return initializer
|
||||
|
@ -1,8 +1,3 @@
|
||||
from .colossalai_layer import *
|
||||
from .fused_bias_gelu import bias_gelu_impl
|
||||
from .parallel_1d import *
|
||||
from .parallel_2d import *
|
||||
from .parallel_2p5d import *
|
||||
from .parallel_3d import *
|
||||
from .parallel_sequence import *
|
||||
from .non_parallel_layers import *
|
||||
from .wrapper import *
|
||||
|
@ -1,11 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import collections.abc
|
||||
from itertools import repeat
|
||||
|
||||
import numpy as np
|
||||
from colossalai.utils.common import print_rank_0
|
||||
import torch
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
|
||||
from colossalai.utils import checkpoint
|
||||
@ -19,8 +18,7 @@ class CheckpointModule(nn.Module):
|
||||
self._use_checkpoint = checkpoint
|
||||
|
||||
def _forward(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'CheckpointModule should implement _forward method instead of origin forward')
|
||||
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._use_checkpoint:
|
||||
@ -36,6 +34,7 @@ class CheckpointModule(nn.Module):
|
||||
self._use_checkpoint = False
|
||||
return super().eval()
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
""" only allow exact division """
|
||||
assert numerator % denominator == 0, \
|
||||
@ -59,7 +58,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
setattr(param, NUM_PARTITIONS, num_partitions)
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
|
@ -1,138 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
def _reduce(input_, parallel_mode):
|
||||
# skip if only one rank involved
|
||||
if gpc.get_world_size(parallel_mode) == 1:
|
||||
return input_
|
||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _split(input_, parallel_mode, dim=-1):
|
||||
# skip if only one rank involved
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, \
|
||||
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
|
||||
f'cannot split tensor evenly'
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
output = tensor_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather(input_, parallel_mode, dim=-1):
|
||||
# skip if only one rank involved
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
"""Pass the input to the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode):
|
||||
ctx.mode = parallel_mode
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output, ctx.mode), None
|
||||
|
||||
|
||||
class _ReduceInput(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode):
|
||||
return _reduce(input_, parallel_mode)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _split(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatinate."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _gather(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _gather(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
def reduce_grad(input_, parallel_mode):
|
||||
return _ReduceGrad.apply(input_, parallel_mode)
|
||||
|
||||
|
||||
def reduce_input(input_, parallel_mode):
|
||||
return _ReduceInput.apply(input_, parallel_mode)
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, parallel_mode, dim):
|
||||
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, parallel_mode, dim):
|
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
231
colossalai/nn/layer/colossalai_layer.py
Normal file
231
colossalai/nn/layer/colossalai_layer.py
Normal file
@ -0,0 +1,231 @@
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import dtype, nn
|
||||
from torch.nn.modules.activation import *
|
||||
from torch.nn.modules.adaptive import *
|
||||
from torch.nn.modules.batchnorm import *
|
||||
from torch.nn.modules.channelshuffle import *
|
||||
from torch.nn.modules.conv import *
|
||||
from torch.nn.modules.distance import *
|
||||
from torch.nn.modules.dropout import *
|
||||
from torch.nn.modules.flatten import *
|
||||
from torch.nn.modules.fold import *
|
||||
from torch.nn.modules.instancenorm import *
|
||||
from torch.nn.modules.linear import *
|
||||
from torch.nn.modules.normalization import *
|
||||
from torch.nn.modules.padding import *
|
||||
from torch.nn.modules.pixelshuffle import *
|
||||
from torch.nn.modules.pooling import *
|
||||
from torch.nn.modules.rnn import *
|
||||
from torch.nn.modules.sparse import *
|
||||
from torch.nn.modules.transformer import *
|
||||
from torch.nn.modules.upsampling import *
|
||||
|
||||
from .. import init as init
|
||||
|
||||
from .vanilla import *
|
||||
from .parallel_1d import *
|
||||
from .parallel_2d import *
|
||||
from .parallel_2p5d import *
|
||||
from .parallel_3d import *
|
||||
from .parallel_sequence import *
|
||||
|
||||
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
||||
|
||||
_parallel_classifier = {
|
||||
None: VanillaClassifier,
|
||||
'1d': VanillaClassifier,
|
||||
'2d': Classifier2D,
|
||||
'2.5d': Classifier2p5D,
|
||||
'3d': Classifier3D
|
||||
}
|
||||
|
||||
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
|
||||
|
||||
_parallel_embedding = {'3d': Embedding3D}
|
||||
|
||||
_parallel_patchembedding = {
|
||||
None: VanillaPatchEmbedding,
|
||||
'1d': VanillaPatchEmbedding,
|
||||
'2d': PatchEmbedding2D,
|
||||
'2.5d': PatchEmbedding2p5D,
|
||||
'3d': PatchEmbedding3D
|
||||
}
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
tensor_parallel: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel is None:
|
||||
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
|
||||
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
|
||||
if bias:
|
||||
bias_initializer(self.layer.bias, fan_in=in_features)
|
||||
else:
|
||||
self.layer = _parallel_linear[tensor_parallel](
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
||||
else:
|
||||
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.norm.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.norm.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.norm(*args)
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
tensor_parallel: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.embed = nn.Embedding(num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=get_current_device(),
|
||||
dtype=dtype,
|
||||
*args,
|
||||
**kwargs)
|
||||
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||
else:
|
||||
self.embed = _parallel_embedding[tensor_parallel](
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_(),
|
||||
tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.embed = _parallel_patchembedding[tensor_parallel](
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_size,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
position_embed_initializer=position_embed_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.embed.bias
|
||||
|
||||
@property
|
||||
def pos_embed(self):
|
||||
return self.embed.pos_embed
|
||||
|
||||
@property
|
||||
def cls_token(self):
|
||||
return self.embed.cls_token
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class Classifier(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: nn.Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.layer = _parallel_classifier[tensor_parallel](
|
||||
in_features,
|
||||
num_classes,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
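# --- Usage sketch (illustrative only; the tensor_parallel argument must match the tensor
# parallel mode the context was launched with, e.g. '2d' requires a 2D process grid) ---
#
#   dense = Linear(768, 3072, tensor_parallel='2d')
#   norm = LayerNorm(768, tensor_parallel='2d')
#   head = Classifier(768, num_classes=1000, tensor_parallel='2d')
#   # With tensor_parallel=None, the wrappers fall back to the plain torch.nn / vanilla modules.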
|
@ -1,8 +0,0 @@
|
||||
from ._vit import (ViTBlock, VanillaViTAttention, VanillaViTBlock, VanillaViTDropPath,
|
||||
VanillaViTHead, VanillaViTMLP, VanillaViTPatchEmbedding)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ViTBlock', 'VanillaViTAttention', 'VanillaViTBlock', 'VanillaViTDropPath',
|
||||
'VanillaViTHead', 'VanillaViTMLP', 'VanillaViTPatchEmbedding'
|
||||
]
|
@ -1,301 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from colossalai.builder import build_layer
|
||||
from colossalai.registry import LAYERS
|
||||
from .._common_utils import to_2tuple
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTBlock(nn.Module):
|
||||
"""Vision Transformer block
|
||||
|
||||
:param attention_cfg: config of attention layer
|
||||
:type attention_cfg: dict
|
||||
:param droppath_cfg: config of drop path
|
||||
:type droppath_cfg: dict
|
||||
:param mlp_cfg: config of MLP layer
|
||||
:type mlp_cfg: dict
|
||||
    :param norm_cfg: config of normalization layer
|
||||
:type norm_cfg: dict
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
attention_cfg: dict,
|
||||
droppath_cfg: dict,
|
||||
mlp_cfg: dict,
|
||||
norm_cfg: dict,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = build_layer(norm_cfg)
|
||||
self.attn = build_layer(attention_cfg)
|
||||
self.drop_path = build_layer(
|
||||
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
|
||||
self.norm2 = build_layer(norm_cfg)
|
||||
self.mlp = build_layer(mlp_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaViTPatchEmbedding(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: size of a patch
|
||||
:type patch_size: int
|
||||
:param in_chans: input channels
|
||||
:type in_chans: int
|
||||
:param embed_dim: embedding dimension
|
||||
:type embed_dim: int
|
||||
:param norm_layer: layer norm class, defaults to None
|
||||
:type norm_layer: Callable
|
||||
    :param flatten: whether to flatten the output
|
||||
:type flatten: bool
|
||||
:param drop: dropout rate
|
||||
:type drop: float
|
||||
"""
|
||||
|
||||
def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=None, flatten=True, drop=0.):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim,
|
||||
kernel_size=patch_size, stride=patch_size)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
x = self.norm(x)
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaViTMLP(nn.Module):
|
||||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||
|
||||
:param in_features: input channels
|
||||
:type in_features: int
|
||||
:param hidden_features: channels of the output of the first dense layer
|
||||
:type hidden_features: int
|
||||
    :param out_features: channels of the output of the second dense layer
    :type out_features: int
|
||||
:param act_layer: activation function
|
||||
:type act_layer: Callable
|
||||
:param drop: dropout rate
|
||||
:type drop: float
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_features, out_features, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
:param drop_prob: probability for dropout
|
||||
:type drop_prob: float
|
||||
:param training: whether it is training mode
|
||||
:type training: bool
|
||||
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
# work with diff dim tensors, not just 2D ConvNets
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + \
|
||||
torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
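# Expected behaviour (illustrative): with drop_prob=0.1 in training mode, each sample's
# residual branch is zeroed out with probability 0.1, and surviving samples are scaled
# by 1 / 0.9, so the output matches the input in expectation:
#   x = torch.ones(4, 197, 384)
#   y = drop_path(x, drop_prob=0.1, training=True)
#   # y[i] is either all zeros or x[i] / 0.9, decided independently for each sample i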
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaViTDropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
:param drop_prob: probability for dropout
|
||||
    :type drop_prob: float
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=0.):
|
||||
super().__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaViTAttention(nn.Module):
|
||||
"""Vanilla attention layer of Vision Transformer
|
||||
|
||||
:param dim: dimension of input tensor
|
||||
:type dim: int
|
||||
:param num_heads: number of attention heads
|
||||
:type num_heads: int, optional
|
||||
:param qkv_bias: enable bias for qkv if True, defaults to False
|
||||
:type qkv_bias: bool, optional
|
||||
:param attn_drop: dropout probability for attention layer, defaults to 0.
|
||||
:type attn_drop: float, optional
|
||||
:param proj_drop: dropout probability for linear layer, defaults to 0.
|
||||
:type proj_drop: float, optional
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
|
||||
self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
# make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
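# Shape walk-through (illustrative) for dim=384, num_heads=6 (head_dim=64):
#   x:    (B, N, 384)
#   qkv:  (B, N, 1152), reshaped and permuted to (3, B, 6, N, 64)
#   attn: (B, 6, N, N) after q @ k^T scaled by 64 ** -0.5 and softmax
#   out:  (B, N, 384) after attn @ v, transpose(1, 2) and reshape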
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaViTBlock(nn.Module):
|
||||
|
||||
"""Vanilla Vision Transformer block
|
||||
|
||||
:param dim: dimension of input tensor
|
||||
:type dim: int
|
||||
:param num_heads: number of attention heads
|
||||
:type num_heads: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
|
||||
:type mlp_ratio: float, optional
|
||||
:param qkv_bias: enable bias for qkv if True, defaults to False
|
||||
:type qkv_bias: bool, optional
|
||||
:param drop: dropout probability, defaults to 0.
|
||||
:type drop: float, optional
|
||||
:param attn_drop: dropout probability for attention layer, defaults to 0.
|
||||
:type attn_drop: float, optional
|
||||
:param drop_path: drop path probability, defaults to 0.
|
||||
:type drop_path: float, optional
|
||||
:param act_layer: activation function, defaults to nn.GELU
|
||||
:type act_layer: torch.nn.Module, optional
|
||||
:param norm_layer: normalization layer, defaults to nn.LayerNorm
|
||||
:type norm_layer: torch.nn.Module, optional
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = LAYERS.get_module('VanillaViTAttention')(dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = LAYERS.get_module('VanillaViTDropPath')(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaViTHead(nn.Module):
|
||||
"""Output layer of vanilla Vision Transformer
|
||||
|
||||
:param in_features: size of input tensor
|
||||
:type in_features: int
|
||||
:param intermediate_features: hidden size
|
||||
:type intermediate_features: int
|
||||
:param out_features: size of output tensor
|
||||
:type out_features: int
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
intermediate_features,
|
||||
out_features,
|
||||
bias=True
|
||||
):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features, intermediate_features, bias=bias)
|
||||
self.act = nn.Tanh()
|
||||
self.linear_2 = nn.Linear(
|
||||
intermediate_features, out_features, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = x[:, 0, :].squeeze(1)
|
||||
x = self.linear_1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
@ -1,11 +1,4 @@
|
||||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
|
||||
from ._transformer import TransformerMLP1D, TransformerSelfAttention1D, TransformerLayer1D
|
||||
from ._vit import ViTMLP1D, ViTSelfAttention1D, ViTHead1D, ViTPatchEmbedding1D, ViTTokenFuser1D, ViTHead
|
||||
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Linear1D_Col', 'Linear1D_Row', 'ViTMLP1D', 'ViTSelfAttention1D', 'ViTHead1D', 'ViTPatchEmbedding1D', 'ViTTokenFuser1D',
|
||||
'TransformerMLP1D', 'TransformerSelfAttention1D', 'TransformerLayer1D', 'LayerNorm1D', 'ViTHead'
|
||||
]
|
||||
__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
|
||||
|
@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
import math
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Tuple
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from .._common_utils import divide, ACT2FN
|
||||
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
|
||||
split_forward_gather_backward
|
||||
from ..base_layer import ParallelLayer
|
||||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerMLP1D(ParallelLayer):
|
||||
"""MLP.
|
||||
    MLP will take the input with hidden size h, project it to mlp_ratio * h
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
                 mlp_ratio: float = 4.0,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
):
|
||||
super(TransformerMLP1D, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.skip_bias_add = skip_bias_add
|
||||
# Project to h * mlp_ratio.
|
||||
self.dense_1 = Linear1D_Col(
|
||||
self.in_features,
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
bias=not skip_bias_add,
|
||||
dtype=dtype,
|
||||
gather_output = False,
|
||||
)
|
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
|
||||
f'activation function can only be {list(ACT2FN.keys())}'
|
||||
self.activation_func = ACT2FN[act_func]
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear1D_Row(
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
self.in_features,
|
||||
bias=not skip_bias_add,
|
||||
dtype=dtype,
|
||||
parallel_input = True,
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
# self.layernorm = LayerNorm1D(in_features, dtype=dtype)
|
||||
self.layernorm = nn.LayerNorm(in_features, dtype=dtype)
|
||||
def forward(self, x):
|
||||
if self.skip_bias_add:
|
||||
intermediate_output, _ = self.dense_1(x)
|
||||
else:
|
||||
intermediate_output = self.dense_1(x)
|
||||
|
||||
intermediate_output = self.activation_func(intermediate_output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
output, _ = self.dense_2(intermediate_output)
|
||||
else:
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
output = self.layernorm(x + output)
|
||||
return output
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerSelfAttention1D(ParallelLayer):
|
||||
"""Self attention layer for 1D parallel Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layer
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layer
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.num_attention_heads = divide(num_attention_heads, gpc.tensor_parallel_size)
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
|
||||
|
||||
self.query_key_value = Linear1D_Col(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear1D_Row(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
# need to re-enable torch grad to enable fused optimization.
|
||||
# self.layernorm = LayerNorm1D(
|
||||
# hidden_size,
|
||||
# dtype=dtype)
|
||||
self.layernorm = nn.LayerNorm(
|
||||
hidden_size,
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
attention_output = self.layernorm(hidden_states + output)
|
||||
|
||||
return attention_output
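# Partitioning example (illustrative): with hidden_size=768, num_attention_heads=12 and
# tensor_parallel_size=4, each rank keeps
#   num_attention_heads       = 12 / 4  = 3 heads,
#   attention_head_size       = 768 / 12 = 64,
#   hidden_size_per_partition = 768 / 4  = 192,
# so the per-rank context_layer is reshaped to (..., 192) before the row-parallel dense layer.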
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerLayer1D(ParallelLayer):
|
||||
"""Transformer layer which contains a self-attention layer and a MLP layer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
|
||||
:type mlp_ratio: float, optional
|
||||
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
|
||||
:type attention_dropout_prob: float, optional
|
||||
    :param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
|
||||
:type hidden_dropout_prob: float, optional
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = TransformerSelfAttention1D(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.mlp = TransformerMLP1D(
|
||||
in_features=hidden_size,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
output = self.mlp(attention_output)
|
||||
return output
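# --- Usage sketch (illustrative only; assumes a 1D tensor-parallel context is initialized) ---
#
#   layer = TransformerLayer1D(hidden_size=768, num_attention_heads=12,
#                              attention_dropout_prob=0.1, hidden_dropout_prob=0.1)
#   # hidden_states: (batch, seq_len, 768); attention_mask: an additive mask broadcastable to
#   # the attention scores, e.g. 0 for visible positions and a large negative value for masked ones
#   out = layer(hidden_states, attention_mask)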
|
@ -1,6 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
from .._common_utils import divide
|
||||
|
||||
|
||||
@ -15,4 +20,128 @@ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
||||
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
|
||||
|
||||
|
||||
def _reduce(input_, parallel_mode):
|
||||
# skip if only one rank involved
|
||||
if gpc.get_world_size(parallel_mode) == 1:
|
||||
return input_
|
||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _split(input_, parallel_mode, dim=-1):
|
||||
# skip if only one rank involved
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
dim_size = input_.size(dim)
|
||||
assert dim_size % world_size == 0, \
|
||||
f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
|
||||
f'cannot split tensor evenly'
|
||||
|
||||
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
output = tensor_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather(input_, parallel_mode, dim=-1):
|
||||
# skip if only one rank involved
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# all gather
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
|
||||
|
||||
# concat
|
||||
output = torch.cat(tensor_list, dim=dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
"""Pass the input to the model parallel region."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode):
|
||||
ctx.mode = parallel_mode
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output, ctx.mode), None
|
||||
|
||||
|
||||
class _ReduceInput(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode):
|
||||
return _reduce(input_, parallel_mode)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, None
|
||||
|
||||
|
||||
class _SplitForwardGatherBackward(torch.autograd.Function):
|
||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _split(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
class _GatherForwardSplitBackward(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatinate."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _gather(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, parallel_mode, dim):
|
||||
ctx.mode = parallel_mode
|
||||
ctx.dim = dim
|
||||
return _gather(input_, parallel_mode, dim)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split(grad_output, ctx.mode, ctx.dim), None, None
|
||||
|
||||
|
||||
def reduce_grad(input_, parallel_mode):
|
||||
return _ReduceGrad.apply(input_, parallel_mode)
|
||||
|
||||
|
||||
def reduce_input(input_, parallel_mode):
|
||||
return _ReduceInput.apply(input_, parallel_mode)
|
||||
|
||||
|
||||
def split_forward_gather_backward(input_, parallel_mode, dim):
|
||||
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
|
||||
|
||||
|
||||
def gather_forward_split_backward(input_, parallel_mode, dim):
|
||||
return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
|
||||
|
@ -1,411 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from colossalai import context
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor, distributed as dist
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils import get_current_device
|
||||
from .layers import Linear1D_Col, Linear1D_Row
|
||||
from ..base_layer import ParallelLayer
|
||||
from .._common_utils import to_2tuple
|
||||
from ..fused_bias_gelu import bias_gelu_impl
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP1D(ParallelLayer):
|
||||
"""MLP layer for 1D parallel Vision Transformer
|
||||
|
||||
:param in_features: size of each input sample
|
||||
:type in_features: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim
|
||||
:type mlp_ratio: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param dropout_prob: dropout probability, defaults to 0.
|
||||
:type dropout_prob: float, optional
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: whether to checkpoint the layer, defaults to False
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
self.skip_bias_add = skip_bias_add
|
||||
assert weight_init in ('torch', 'jax')
|
||||
|
||||
if act_func == 'fused_gelu':
|
||||
self.act = bias_gelu_impl
|
||||
skip_dense_1_add_bias = True
|
||||
else:
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
|
||||
self.dense_1 = Linear1D_Col(
|
||||
self.in_features,
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
dtype=dtype,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_dense_1_add_bias,
|
||||
init_weight=weight_init,
|
||||
init_bias=weight_init
|
||||
)
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear1D_Row(
|
||||
int(self.mlp_ratio * self.in_features),
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
init_weight=weight_init, init_bias=weight_init
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.act == bias_gelu_impl:
|
||||
intermediate_output, bias = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output, bias)
|
||||
else:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
intermediate_output = self.dropout(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTSelfAttention1D(ParallelLayer):
|
||||
"""Self-attention layer for 1D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layers
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: whether to checkpoint the layer, defaults to False
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
|
||||
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
|
||||
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
init_bias = 'zero'
|
||||
else:
|
||||
init_bias = weight_init
|
||||
|
||||
self.query_key_value = Linear1D_Col(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=init_bias
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear1D_Row(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
parallel_input=True,
|
||||
init_weight=weight_init, init_bias=init_bias
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.hidden_size_per_partition,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
output = self.dense(context_layer)
|
||||
output = self.dropout(output)
|
||||
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead1D(ParallelLayer):
|
||||
"""Output layer for 1D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
init_weight = 'zero'
|
||||
init_bias = 'zero'
|
||||
else:
|
||||
init_weight = weight_init
|
||||
init_bias = weight_init
|
||||
|
||||
self.linear = Linear1D_Col(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype,
|
||||
gather_output=True,
|
||||
init_weight=init_weight,
|
||||
init_bias=init_bias
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead(ParallelLayer):
|
||||
"""Output layer for 1D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype
|
||||
)
|
||||
self._broadcast_linear_params()
|
||||
|
||||
def _broadcast_linear_params(self) -> None:
|
||||
self.to(get_current_device())
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
|
||||
|
||||
dist.broadcast(self.linear.weight, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
dist.broadcast(self.linear.bias, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTPatchEmbedding1D(ParallelLayer):
|
||||
""" 2D Image to Patch Embedding
|
||||
|
||||
    :param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param in_chans: number of channels of input image, defaults to 3
|
||||
:type in_chans: int, optional
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
in_chans=3,
|
||||
flatten=True,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size
|
||||
)
|
||||
|
||||
if weight_init == 'jax':
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
# sync
|
||||
self._broadcast_conv_params()
|
||||
|
||||
def _broadcast_conv_params(self) -> None:
|
||||
self.to(get_current_device())
|
||||
ranks = gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)
|
||||
|
||||
dist.broadcast(self.proj.weight, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
dist.broadcast(self.proj.bias, src=ranks[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
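# Output shape example (illustrative): img_size=224, patch_size=16, embed_dim=768 gives
#   grid_size = (14, 14), num_patches = 196,
# so with flatten=True the forward maps input images of shape (B, 3, 224, 224)
# to patch tokens of shape (B, 196, 768); the cls token is added later by ViTTokenFuser1D.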
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTTokenFuser1D(ParallelLayer):
|
||||
"""
|
||||
    Fuse the cls token and position embedding into the input
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param drop_rate: dropout probability, defaults to 0.
|
||||
:type drop_rate: float, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
drop_rate=0.
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(
|
||||
1, 1, self.embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.empty(
|
||||
1, self.num_patches + 1, self.embed_dim))
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
# move to cuda before broadcast
|
||||
self.to(get_current_device())
|
||||
dist.broadcast(self.pos_embed,
|
||||
src=gpc.get_ranks_in_group(ParallelMode.TENSOR)[0],
|
||||
group=gpc.get_group(ParallelMode.TENSOR))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
return x.contiguous()
|
@ -3,25 +3,24 @@
|
||||
|
||||
import math
|
||||
import numbers
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from typing import Tuple
|
||||
import importlib
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.communication import broadcast
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import FusedLayerNormAffineFunction1D
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from .._parallel_utilities import reduce_grad, reduce_input, gather_forward_split_backward, \
|
||||
split_forward_gather_backward
|
||||
from ..base_layer import ParallelLayer
|
||||
from ._operation import FusedLayerNormAffineFunction1D
|
||||
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
@ -44,79 +43,46 @@ class Linear1D_Col(ParallelLayer):
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
:type gather_output: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
output_size: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'
|
||||
):
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = output_size
|
||||
self.out_features = out_features
|
||||
self.gather_output = gather_output
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
self.output_size_per_partition = divide(output_size, gpc.tensor_parallel_size)
|
||||
self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.in_features,
|
||||
**factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
**factory_kwargs))
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.bias = None
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
@ -133,8 +99,7 @@ class Linear1D_Col(ParallelLayer):
|
||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_forward_split_backward(
|
||||
output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
else:
|
||||
output = output_parallel
|
||||
if self.skip_bias_add:
|
||||
@ -158,17 +123,15 @@ class Linear1D_Row(ParallelLayer):
|
||||
    :param parallel_input: If set to ``True``, the input is assumed to be already split across ranks, defaults to False
|
||||
:type parallel_input: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
parallel_input: bool = False,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'
|
||||
):
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
@ -186,58 +149,22 @@ class Linear1D_Row(ParallelLayer):
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.out_features,
|
||||
self.input_size_per_partition,
|
||||
**factory_kwargs))
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.out_features,
|
||||
**factory_kwargs
|
||||
))
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.bias = None
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
|
||||
dist.broadcast(self.bias,
|
||||
src=gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0],
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
@ -248,8 +175,7 @@ class Linear1D_Row(ParallelLayer):
|
||||
if self.parallel_input:
|
||||
input_ = input_
|
||||
else:
|
||||
input_ = split_forward_gather_backward(
|
||||
input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
@ -263,12 +189,13 @@ class Linear1D_Row(ParallelLayer):
|
||||
|
||||
@LAYERS.register_module
|
||||
class MixedFusedLayerNorm1D(torch.nn.Module):
|
||||
|
||||
""" Experimental
|
||||
"""
|
||||
def __init__(self, normalized_shape, eps=1e-5):
|
||||
super(MixedFusedLayerNorm1D, self).__init__()
|
||||
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
normalized_shape = (normalized_shape, )
|
||||
self.normalized_shape = torch.Size(normalized_shape)
|
||||
self.eps = eps
|
||||
self.weight = Parameter(torch.Tensor(*normalized_shape))
|
||||
@ -280,5 +207,4 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, input):
|
||||
return FusedLayerNormAffineFunction1D.apply(
|
||||
input, self.weight, self.bias, self.normalized_shape, self.eps)
|
||||
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
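# --- Usage sketch (illustrative only; assumes a 1D tensor-parallel group of size N) ---
#
#   # Megatron-style MLP split: a column-parallel up-projection followed by a row-parallel
#   # down-projection, so no communication is needed between the two linear layers.
#   dense_1 = Linear1D_Col(768, 3072, gather_output=False)   # local weight shard: (3072 / N, 768)
#   dense_2 = Linear1D_Row(3072, 768, parallel_input=True)   # local weight shard: (768, 3072 / N)
#   y = dense_2(torch.nn.functional.gelu(dense_1(x)))        # output is all-reduced inside Linear1D_Row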
|
||||
|
@ -1,11 +1,6 @@
|
||||
from ._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D, Add_Bias_2D, matmul_2d
|
||||
from ._transformer import TransformerMLP2D, TransformerSelfAttention2D, TransformerLayer2D
|
||||
from ._vit import ViTMLP2D, ViTSelfAttention2D, ViTHead2D, ViTPatchEmbedding2D, ViTTokenFuser2D, ViTInputSplitter2D
|
||||
from .layers import Linear2D, LayerNorm2D
|
||||
from ._operation import reduce_by_batch_2d, split_batch_2d
|
||||
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
|
||||
|
||||
__all__ = [
|
||||
'Matmul_AB_2D', 'Matmul_ABT_2D', 'Matmul_ATB_2D', 'Add_Bias_2D', 'matmul_2d',
|
||||
'TransformerMLP2D', 'TransformerSelfAttention2D', 'TransformerLayer2D',
|
||||
'ViTMLP2D', 'ViTSelfAttention2D', 'ViTHead2D', 'ViTPatchEmbedding2D', 'ViTTokenFuser2D', 'ViTInputSplitter2D',
|
||||
'Linear2D', 'LayerNorm2D'
|
||||
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
|
||||
]
|
||||
|
@ -1,24 +1,25 @@
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
def matmul_2d(a,
|
||||
b,
|
||||
summa_dim,
|
||||
out_shape,
|
||||
row_rank=None,
|
||||
col_rank=None,
|
||||
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
|
||||
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
|
||||
):
|
||||
def matmul_2d(
|
||||
a,
|
||||
b,
|
||||
summa_dim,
|
||||
out_shape,
|
||||
row_rank=None,
|
||||
col_rank=None,
|
||||
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
|
||||
col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
|
||||
):
|
||||
"""Matrix multiplication for 2D parallelism
|
||||
:param a: matrix :math:`A`
|
||||
:type a: torch.tensor
|
||||
@@ -44,16 +45,87 @@ def matmul_2d(a,
|
||||
if col_rank is None:
|
||||
col_rank = gpc.get_local_rank(row_parallel_mode)
|
||||
|
||||
data_parallel_rank = 0 if not gpc.is_initialized(
|
||||
ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
|
||||
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
|
||||
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
|
||||
ParallelMode.PIPELINE)
|
||||
tensor_parallel_size = summa_dim ** 2
|
||||
tensor_parallel_size = summa_dim**2
|
||||
return Matmul_AB_2D(a, b, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode, col_parallel_mode,
|
||||
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size
|
||||
)
|
||||
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
|
||||
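For intuition, here is a small single-process sketch (my illustration, not code from this commit) of the SUMMA-style block algorithm that `matmul_2d` delegates to `Matmul_AB_2D`: each of the q x q ranks owns one block of A and B, and the product is accumulated over q broadcast rounds.
```
import torch

q = 2                                   # summa_dim, i.e. a q x q process grid
A = torch.randn(8, 6)
B = torch.randn(6, 4)
A_blk = [list(t.chunk(q, dim=1)) for t in A.chunk(q, dim=0)]
B_blk = [list(t.chunk(q, dim=1)) for t in B.chunk(q, dim=0)]

# C[i][j] lives on rank (i, j); round k broadcasts A[i][k] along row i and
# B[k][j] along column j, mirroring the async broadcasts in Matmul_AB_2D.forward.
C_blk = [[sum(A_blk[i][k] @ B_blk[k][j] for k in range(q)) for j in range(q)] for i in range(q)]

C = torch.cat([torch.cat(row, dim=1) for row in C_blk], dim=0)
assert torch.allclose(C, A @ B, atol=1e-5)
```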
|
||||
|
||||
class classifier_2d(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
bias: Optional[Tensor],
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
B_temp = all_gather(B, -1, col_parallel_mode)
|
||||
if ctx:
|
||||
ctx.save_for_backward(A, B_temp)
|
||||
|
||||
C = torch.matmul(A, B_temp.transpose(0, 1))
|
||||
|
||||
C = all_reduce(C, row_parallel_mode)
|
||||
|
||||
ctx.use_bias = bias is not None
|
||||
if bias is not None:
|
||||
C = C + bias
|
||||
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
ctx.summa_dim = summa_dim
|
||||
ctx.row_rank = row_rank
|
||||
ctx.col_rank = col_rank
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.A_shape = A_shape
|
||||
ctx.B_shape = B_shape
|
||||
ctx.data_parallel_rank = data_parallel_rank
|
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank
|
||||
ctx.pipeline_parallel_size = pipeline_parallel_size
|
||||
ctx.tensor_parallel_size = tensor_parallel_size
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
|
||||
with torch.no_grad():
|
||||
A_grad = torch.matmul(output_grad, B)
|
||||
A_grad = A_grad.reshape(ctx.A_shape)
|
||||
B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
|
||||
B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
|
||||
B_grad = B_grad.reshape(ctx.B_shape)
|
||||
|
||||
bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
|
||||
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
|
||||
|
||||
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Matmul_AB_2D(torch.autograd.Function):
|
||||
@@ -61,19 +133,21 @@ class Matmul_AB_2D(torch.autograd.Function):
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int) -> Tensor:
|
||||
def forward(
|
||||
ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
# A: [b / q, s, h / q] -> [(b * s) / q, h / q]
|
||||
# B: [h / q, s / q]
|
||||
# C: [b / q, s, s / q] -> [(b * s) / q, s / q]
|
||||
@@ -116,15 +190,9 @@ class Matmul_AB_2D(torch.autograd.Function):
|
||||
for i in range(summa_dim):
|
||||
if i != summa_dim - 1:
|
||||
A_list[1 - cur].copy_(A)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur],
|
||||
src=src_a + 1,
|
||||
group=row_group,
|
||||
async_op=True)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||
B_list[1 - cur].copy_(B)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||
src=src_b + summa_dim,
|
||||
group=col_group,
|
||||
async_op=True)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True)
|
||||
|
||||
if opa[cur] is not None:
|
||||
opa[cur].wait()
|
||||
@@ -157,28 +225,14 @@ class Matmul_AB_2D(torch.autograd.Function):
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2D.apply(
|
||||
output_grad, B,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2D.apply(
|
||||
A, output_grad,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
A_grad = Matmul_ABT_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size)
|
||||
B_grad = Matmul_ATB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@@ -187,20 +241,21 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int
|
||||
) -> Tensor:
|
||||
def forward(
|
||||
ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], \
|
||||
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
|
||||
@@ -238,10 +293,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
||||
for i in range(summa_dim):
|
||||
if i != summa_dim - 1:
|
||||
B_list[1 - cur].copy_(B)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||
src=src_b + summa_dim,
|
||||
group=col_group,
|
||||
async_op=True)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim, group=col_group, async_op=True)
|
||||
|
||||
if opr[cur] is not None:
|
||||
opr[cur].wait()
|
||||
@@ -287,28 +339,14 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
||||
A, B = ctx.saved_tensors
|
||||
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_AB_2D.apply(
|
||||
output_grad, B,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2D.apply(
|
||||
output_grad, A,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
A_grad = Matmul_AB_2D.apply(output_grad, B, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size)
|
||||
B_grad = Matmul_ATB_2D.apply(output_grad, A, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@@ -317,20 +355,21 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int
|
||||
) -> Tensor:
|
||||
def forward(
|
||||
ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
summa_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
|
||||
assert A.shape[-2] == B.shape[-2], \
|
||||
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
|
||||
@@ -368,10 +407,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
||||
for i in range(summa_dim):
|
||||
if i != summa_dim - 1:
|
||||
A_list[1 - cur].copy_(A)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur],
|
||||
src=src_a + 1,
|
||||
group=row_group,
|
||||
async_op=True)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||
|
||||
if opr[cur] is not None:
|
||||
opr[cur].wait()
|
||||
@@ -417,61 +453,38 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
||||
A, B = ctx.saved_tensors
|
||||
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2D.apply(
|
||||
B, output_grad,
|
||||
ctx.summa_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_AB_2D.apply(
|
||||
A, output_grad,
|
||||
ctx.summa_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
A_grad = Matmul_ABT_2D.apply(B, output_grad, ctx.summa_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size)
|
||||
B_grad = Matmul_AB_2D.apply(A, output_grad, ctx.summa_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.row_parallel_mode, ctx.col_parallel_mode, ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank, ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Add_Bias_2D(torch.autograd.Function):
|
||||
class add_bias_2d(torch.autograd.Function):
|
||||
"""Matrix add bias: :math:`C = A + b`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
bias: Tensor,
|
||||
output_size_per_partition: int,
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
skip_bias_add: bool,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int
|
||||
) -> Tensor:
|
||||
if row_rank == 0:
|
||||
bias_temp = bias.clone()
|
||||
else:
|
||||
bias_temp = torch.zeros(
|
||||
output_size_per_partition,
|
||||
dtype=bias.dtype,
|
||||
device=get_current_device())
|
||||
src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(bias_temp, src=src_rank,
|
||||
group=gpc.get_group(col_parallel_mode))
|
||||
def forward(
|
||||
ctx: Any,
|
||||
input_: Tensor,
|
||||
bias: Tensor,
|
||||
output_size_per_partition: int,
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
skip_bias_add: bool,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
bias_temp = all_gather(bias, -1, col_parallel_mode)
|
||||
|
||||
ctx.row_rank = row_rank
|
||||
ctx.col_rank = col_rank
|
||||
@@ -486,62 +499,33 @@ class Add_Bias_2D(torch.autograd.Function):
|
||||
if skip_bias_add:
|
||||
return bias_temp
|
||||
else:
|
||||
output = input + bias_temp
|
||||
output = input_ + bias_temp
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
row_rank = ctx.row_rank
|
||||
col_rank = ctx.col_rank
|
||||
row_parallel_mode = ctx.row_parallel_mode
|
||||
col_parallel_mode = ctx.col_parallel_mode
|
||||
data_parallel_rank = ctx.data_parallel_rank
|
||||
pipeline_parallel_rank = ctx.pipeline_parallel_rank
|
||||
pipeline_parallel_size = ctx.pipeline_parallel_size
|
||||
tensor_parallel_size = ctx.tensor_parallel_size
|
||||
|
||||
if ctx.bias:
|
||||
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(output_grad, dst=dst_rank,
|
||||
group=gpc.get_group(col_parallel_mode))
|
||||
if row_rank == 0:
|
||||
return None, output_grad, None, None, None, None, None, None, None, None, None, None
|
||||
else:
|
||||
# for compatibility with zero optimizer, no grad should be None
|
||||
grad_tmp = torch.zeros_like(output_grad)
|
||||
return None, grad_tmp, None, None, None, None, None, None, None, None, None, None
|
||||
grad = reduce_scatter(output_grad, -1, col_parallel_mode)
|
||||
return None, grad, None, None, None, None, None, None, None, None, None, None
|
||||
else:
|
||||
reduce_dim = tuple(range(output_grad.ndim - 1))
|
||||
reduce = torch.sum(output_grad, dim=reduce_dim)
|
||||
dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(reduce, dst=dst_rank,
|
||||
group=gpc.get_group(col_parallel_mode))
|
||||
if row_rank == 0:
|
||||
return output_grad, reduce, None, None, None, None, None, None, None, None, None, None
|
||||
else:
|
||||
# for compatibility with zero optimizer, no grad should be None
|
||||
reduce_tmp = torch.zeros_like(reduce)
|
||||
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None
|
||||
grad = reduce_scatter(reduce, -1, col_parallel_mode)
|
||||
return output_grad, grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class _LayerNorm_2D(torch.autograd.Function):
|
||||
|
||||
class layernorm_2d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
E_x: Tensor,
|
||||
Var_x: Tensor,
|
||||
hidden_size: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
input = input - E_x
|
||||
input_ = input_ - E_x
|
||||
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
|
||||
ctx.normalized_shape = hidden_size
|
||||
output = input * Var_x
|
||||
output = input_ * Var_x
|
||||
ctx.save_for_backward(output, Var_x)
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
@@ -555,14 +539,11 @@ class _LayerNorm_2D(torch.autograd.Function):
|
||||
x, Var_x = ctx.saved_tensors
|
||||
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
|
||||
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
|
||||
torch.distributed.all_reduce(
|
||||
output_grad_sum, group=gpc.get_group(row_parallel_mode))
|
||||
torch.distributed.all_reduce(output_grad_sum, group=gpc.get_group(row_parallel_mode))
|
||||
output_grad_sum /= ctx.normalized_shape
|
||||
|
||||
output_grad_mul_x_sum = torch.sum(
|
||||
output_grad * x, dim=-1, keepdim=True)
|
||||
torch.distributed.all_reduce(
|
||||
output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
|
||||
output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True)
|
||||
torch.distributed.all_reduce(output_grad_mul_x_sum, group=gpc.get_group(row_parallel_mode))
|
||||
output_grad_mul_x_sum /= ctx.normalized_shape
|
||||
|
||||
input_grad = output_grad.clone()
|
||||
@@ -573,69 +554,28 @@ class _LayerNorm_2D(torch.autograd.Function):
|
||||
return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
# class Sum_2D(torch.autograd.Function):
|
||||
#
|
||||
# @staticmethod
|
||||
# def forward(ctx: Any,
|
||||
# inputs: Tensor,
|
||||
# dim: int,
|
||||
# summa_dim: int,
|
||||
# row_parallel_mode: ParallelMode,
|
||||
# keepdim: bool = False) -> Tensor:
|
||||
# # input: [b/q, s, h/q]
|
||||
# empty_cache()
|
||||
# ctx.save_for_backward(inputs)
|
||||
# # sum: [b/q, s]
|
||||
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
|
||||
# torch.distributed.all_reduce(out, group=gpc.get_group(row_parallel_mode))
|
||||
# return out
|
||||
#
|
||||
# @staticmethod
|
||||
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# with torch.no_grad():
|
||||
# inputs = ctx.saved_tensors
|
||||
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
|
||||
# return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
class AllGatherLast(torch.autograd.Function):
|
||||
|
||||
class all_gather_weight_2d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
summa_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.dim = dim
|
||||
ctx.summa_dim = summa_dim
|
||||
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
|
||||
last_dim = summa_dim * inputs.size(-1)
|
||||
outputs_shape = (last_dim,) + inputs.shape[:-1]
|
||||
outputs = torch.empty(
|
||||
outputs_shape, dtype=inputs.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(outputs.chunk(summa_dim, dim=0)),
|
||||
inputs.permute(2, 0, 1).contiguous(),
|
||||
group=gpc.get_group(col_parallel_mode)
|
||||
)
|
||||
outputs = outputs.permute(1, 2, 0).contiguous()
|
||||
outputs = all_gather(inputs, dim, col_parallel_mode)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad = output_grad.chunk(ctx.summa_dim, dim=-1)[ctx.row_rank]
|
||||
return grad.contiguous(), None, None
|
||||
grad = output_grad.chunk(ctx.summa_dim, dim=ctx.dim)[ctx.row_rank]
|
||||
return grad.contiguous(), None, None, None
|
||||
|
||||
|
||||
class SplitFirst(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
summa_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.summa_dim = summa_dim
|
||||
ctx.batch_size = inputs.size(0)
|
||||
ctx.para_mode = col_parallel_mode
|
||||
@@ -647,12 +587,33 @@ class SplitFirst(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grad = torch.empty(
|
||||
grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(grad.chunk(ctx.summa_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.para_mode)
|
||||
)
|
||||
grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
|
||||
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
dist.all_gather(list(grad.chunk(ctx.summa_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.para_mode))
|
||||
return grad, None, None


def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()


class reduce_by_batch_2d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_.clone()

@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
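A single-process sketch (my own illustration, not part of the commit) of what these two helpers do: `split_batch_2d` hands each rank in the PARALLEL_2D_COL group one shard of the batch, and `reduce_by_batch_2d` all-reduces a per-shard quantity so every rank ends up with the full-batch value.
```
import torch

world = 4                        # assumed size of the PARALLEL_2D_COL group
x = torch.randn(8, 16)           # global batch of 8 samples
shards = x.chunk(world, dim=0)   # what split_batch_2d would hand to each rank

partial = [s.sum() for s in shards]   # per-rank partial statistic
total = sum(partial)                  # what the all-reduce in reduce_by_batch_2d yields
assert torch.allclose(total, x.sum())
```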
|
@@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor
|
||||
|
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
from colossalai.registry import LAYERS
|
||||
from .layers import Linear2D, LayerNorm2D
|
||||
from ..base_layer import ParallelLayer
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerMLP2D(ParallelLayer):
|
||||
"""
|
||||
MLP will take the input with h hidden state, project it to mlp_ratio * h
|
||||
hidden dimension, perform nonlinear transformation, and project the
|
||||
state back into h hidden dimension. At the end, dropout is also
|
||||
applied.
|
||||
|
||||
:param in_features: the size of input tensor
|
||||
:type in_features: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
|
||||
:type mlp_ratio: int, optional
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param dropout_prob: dropout probability, defaults to 0.
|
||||
:type dropout_prob: float, optional
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
|
||||
:type skip_bias_add: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int = 4.0,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.in_features = in_features
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
# Project to h * mlp_ratio.
|
||||
self.dense_1 = Linear2D(
|
||||
in_features,
|
||||
int(mlp_ratio * in_features),
|
||||
dtype=dtype,
|
||||
skip_bias_add=self.skip_bias_add
|
||||
)
|
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
|
||||
f'activation function can only be {list(ACT2FN.keys())}'
|
||||
self.activation_func = ACT2FN[act_func]
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear2D(
|
||||
int(mlp_ratio * in_features),
|
||||
in_features,
|
||||
dtype=dtype,
|
||||
skip_bias_add=self.skip_bias_add
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
self.layernorm = LayerNorm2D(in_features, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if self.skip_bias_add:
|
||||
intermediate_output, _ = self.dense_1(x)
|
||||
else:
|
||||
intermediate_output = self.dense_1(x)
|
||||
|
||||
intermediate_output = self.activation_func(intermediate_output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
output, _ = self.dense_2(intermediate_output)
|
||||
else:
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
output = self.dropout(output)
|
||||
output = self.layernorm(x + output)
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerSelfAttention2D(ParallelLayer):
|
||||
"""Self attention layer for 2D parallel Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layer
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layer
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query_key_value = Linear2D(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear2D(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.layernorm = LayerNorm2D(
|
||||
hidden_size,
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
output = self.dropout(output)
|
||||
attention_output = self.layernorm(hidden_states + output)
|
||||
|
||||
return attention_output
|
||||
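The forward pass above is ordinary scaled-dot-product attention; a minimal single-head sketch of the same math (my illustration, with arbitrary shapes):
```
import math

import torch

b, s, d = 2, 5, 8
q, k, v = (torch.randn(b, s, d) for _ in range(3))

scores = q @ k.transpose(-1, -2) / math.sqrt(d)   # attention_scores
probs = torch.softmax(scores, dim=-1)             # attention_probs
context = probs @ v                               # context_layer before the output projection
```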
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerLayer2D(ParallelLayer):
|
||||
"""Transformer layer which contains a self-attention layer and a MLP layer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
|
||||
:type mlp_ratio: float, optional
|
||||
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
|
||||
:type attention_dropout_prob: float, optional
|
||||
:param hidden_dropout_prob: dropout probability for hidden layer, defaults to 0.
|
||||
:type hidden_dropout_prob: float, optional
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = TransformerSelfAttention2D(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.mlp = TransformerMLP2D(
|
||||
in_features=hidden_size,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
output = self.mlp(attention_output)
|
||||
return output
|
@@ -1,397 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor, distributed as dist
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.nn.layer._common_utils import divide, ACT2FN
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.core import global_context as gpc
|
||||
from ._operation import AllGatherLast, SplitFirst
|
||||
from .layers import Linear2D
|
||||
from .._common_utils import set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..fused_bias_gelu import bias_gelu_impl
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP2D(ParallelLayer):
|
||||
"""MLP layer for 2D parallel Vision Transformer
|
||||
|
||||
:param in_features: size of each input sample
|
||||
:type in_features: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim
|
||||
:type mlp_ratio: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param dropout_prob: dropout probability, defaults to 0.
|
||||
:type dropout_prob: float, optional
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: whether to checkpoint the layer, defaults to False
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
|
||||
if act_func == 'fused_gelu':
|
||||
self.act = bias_gelu_impl
|
||||
skip_dense_1_add_bias = True
|
||||
else:
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
|
||||
self.dense_1 = Linear2D(
|
||||
self.in_features,
|
||||
self.mlp_ratio * self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=weight_init,
|
||||
skip_bias_add=skip_dense_1_add_bias
|
||||
)
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear2D(
|
||||
self.mlp_ratio * self.in_features,
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=weight_init
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.act == bias_gelu_impl:
|
||||
intermediate_output, bias = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output, bias)
|
||||
else:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
intermediate_output = self.dropout(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTSelfAttention2D(ParallelLayer):
|
||||
"""Self-attention layer for 2D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layers
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: whether to checkpoint the layer, defaults to False
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = divide(num_attention_heads, self.summa_dim)
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_bias = weight_init
|
||||
|
||||
self.query_key_value = Linear2D(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=self.init_bias
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear2D(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init, init_bias=self.init_bias
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead2D(ParallelLayer):
|
||||
"""Output layer for 2D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_weight = 'zero'
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_weight = weight_init
|
||||
self.init_bias = weight_init
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.linear = Linear2D(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype,
|
||||
init_weight=self.init_weight, init_bias=self.init_bias
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTPatchEmbedding2D(ParallelLayer):
|
||||
""" 2D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param in_chans: number of channels of input image, defaults to 3
|
||||
:type in_chans: int, optional
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
in_chans=3,
|
||||
flatten=True,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_dim = embed_dim // (self.summa_dim ** 2)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=get_current_device()
|
||||
)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
if weight_init == 'jax':
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
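The flatten/transpose step above is the usual ViT patchification; a tiny standalone sketch of the shape bookkeeping (sizes are my own example):
```
import torch
from torch import nn

B, C, H, W, P, E = 2, 3, 32, 32, 16, 24
proj = nn.Conv2d(C, E, kernel_size=P, stride=P)
feat = proj(torch.randn(B, C, H, W))       # [B, E, H/P, W/P] = [2, 24, 2, 2]
tokens = feat.flatten(2).transpose(1, 2)   # BCHW -> BNC, i.e. [2, 4, 24]
```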
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTInputSplitter2D(ParallelLayer):
|
||||
"""Split the input tensor for 2D parallel Vision Transformer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = AllGatherLast.apply(
|
||||
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
x = SplitFirst.apply(
|
||||
x, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTTokenFuser2D(ParallelLayer):
|
||||
"""
|
||||
Fuse cls token and pos embedding to the input
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param drop_rate: dropout probability, defaults to 0.
|
||||
:type drop_rate: float, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
drop_rate=0.
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(
|
||||
(1, 1, self.embed_dim // (self.summa_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
self.pos_embed = nn.Parameter(torch.empty(
|
||||
(1, self.num_patches + 1, self.embed_dim // (self.summa_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_token = AllGatherLast.apply(
|
||||
self.cls_token, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
cls_token = cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
pos_embed = AllGatherLast.apply(
|
||||
self.pos_embed, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
x = x + pos_embed
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.pos_drop(x)
|
||||
return x
|
@@ -1,18 +1,22 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.nn import Parameter, init as init
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.communication import broadcast
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import Matmul_AB_2D, Add_Bias_2D, _LayerNorm_2D
|
||||
from ._utils import get_summa_dim_from_env, assert_summa_initialization
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..base_layer import ParallelLayer
|
||||
from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
|
||||
from ._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
@@ -30,15 +34,14 @@ class Linear2D(ParallelLayer):
|
||||
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
|
||||
:type skip_bias_add: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'):
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
@@ -52,118 +55,57 @@ class Linear2D(ParallelLayer):
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(
|
||||
self.in_features, self.summa_dim)
|
||||
self.hidden_size_per_partition = divide(
|
||||
self.out_features, self.summa_dim)
|
||||
self.input_size_per_partition = divide(self.in_features, self.summa_dim)
|
||||
self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.input_size_per_partition,
|
||||
self.hidden_size_per_partition,
|
||||
**factory_kwargs))
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs))
|
||||
|
||||
# create bias, shape: [h/q]
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.hidden_size_per_partition,
|
||||
**factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(divide(self.out_features, self.summa_dim**2), **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
# initialize parameters
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
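The reworked layers replace the old 'torch'/'jax' string switches with initializer factories: `weight_initializer` and `bias_initializer` are callables that receive the parameter shard plus its global fan-in/fan-out. A minimal sketch of that pattern (my own mock of such a factory, not the actual `colossalai.nn.init` implementation):
```
import math

import torch
from torch.nn import init


def kaiming_uniform_(a=0):
    # returns a closure invoked later as initializer(tensor, fan_in=..., fan_out=...)
    def initializer(tensor, fan_in=None, fan_out=None):
        # same bound the old 'torch' branch computed, but with the global fan_in
        # supplied by the layer, since a sharded parameter no longer exposes it
        std = init.calculate_gain('leaky_relu', a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        return init.uniform_(tensor, -bound, bound)
    return initializer


shard = torch.empty(256, 128)                                 # one 2D shard of a larger weight
kaiming_uniform_(a=math.sqrt(5))(shard, fan_in=1024, fan_out=512)
```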
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/q, n/q, k/q]
|
||||
# output: [m/q, n/q, h/q]
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
|
||||
|
||||
output = Matmul_AB_2D.apply(
|
||||
x,
|
||||
self.weight,
|
||||
self.summa_dim,
|
||||
out_shape,
|
||||
self.row_rank,
|
||||
self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
output = Matmul_AB_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
|
||||
if self.bias is not None:
|
||||
if self.skip_bias_add:
|
||||
bias = Add_Bias_2D.apply(
|
||||
None,
|
||||
self.bias,
|
||||
self.hidden_size_per_partition,
|
||||
self.row_rank,
|
||||
self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
bias = add_bias_2d.apply(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
return output, bias
|
||||
else:
|
||||
output = Add_Bias_2D.apply(
|
||||
output,
|
||||
self.bias,
|
||||
self.hidden_size_per_partition,
|
||||
self.row_rank,
|
||||
self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
False,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
output = add_bias_2d.apply(output, self.bias, self.hidden_size_per_partition, self.row_rank,
|
||||
self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
|
||||
False, self.data_parallel_rank, self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
return output
|
||||
else:
|
||||
return output
|
||||
@@ -183,12 +125,7 @@ class LayerNorm2D(ParallelLayer):
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
normalized_shape: int,
|
||||
eps: float = 1e-05,
|
||||
dtype=None
|
||||
):
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
# layer norm config
|
||||
@@ -202,63 +139,252 @@ class LayerNorm2D(ParallelLayer):
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.partitioned_partition = divide(normalized_shape, self.summa_dim)
|
||||
self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
|
||||
|
||||
# create parameters
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
self.gamma = Parameter(torch.ones(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, self.summa_dim**2)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
with torch.no_grad():
|
||||
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
|
||||
torch.distributed.all_reduce(
|
||||
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
E_x /= self.normalized_shape
|
||||
|
||||
# Var_x in the block below is the sum of input^2
|
||||
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
|
||||
torch.distributed.all_reduce(
|
||||
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
Var_x /= self.normalized_shape
|
||||
|
||||
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
|
||||
# this time 1/sqrt(Var_x + epsilon)
|
||||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
|
||||
|
||||
output = _LayerNorm_2D.apply(x, E_x, Var_x, self.normalized_shape,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL)
|
||||
bias = Add_Bias_2D.apply(
|
||||
None, self.beta, self.partitioned_partition,
|
||||
self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
scale = Add_Bias_2D.apply(
|
||||
None, self.gamma, self.partitioned_partition,
|
||||
self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
output = layernorm_2d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
|
||||
ParallelMode.PARALLEL_2D_COL)
|
||||
bias = add_bias_2d.apply(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
scale = add_bias_2d.apply(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
|
||||
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
return output
|
||||
|
||||
|
||||
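The forward pass above rebuilds layer normalization from two row-wise all-reduced moments, E[x] and E[x*x]. As a point of reference, here is a minimal single-device sketch of that arithmetic (tensor sizes and names are illustrative, not part of the patch); it is not the parallel kernel, only the math it distributes.
```
# Single-device reference for the statistics LayerNorm2D assembles with
# row-wise all-reduces: Var[x] = E[x^2] - E[x]^2, then 1 / sqrt(Var + eps).
import torch

def layernorm_reference(x, gamma, beta, eps=1e-05):
    mean = x.mean(dim=-1, keepdim=True)                 # E[x]
    mean_sq = (x * x).mean(dim=-1, keepdim=True)        # E[x^2]
    inv_std = 1.0 / torch.sqrt(mean_sq - mean * mean + eps)
    # mirrors torch.addcmul(bias, scale, output) in the forward above
    return beta + gamma * (x - mean) * inv_std

x = torch.randn(2, 4, 8)
out = layernorm_reference(x, torch.ones(8), torch.zeros(8))
assert torch.allclose(out, torch.nn.functional.layer_norm(x, (8,)), atol=1e-4)
```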
@LAYERS.register_module
|
||||
class PatchEmbedding2D(ParallelLayer):
|
||||
""" 2D Image to Patch Embedding
|
||||
|
||||
:param img_size: iamge size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param in_chans: number of channels of input image, defaults to 3
|
||||
:type in_chans: int, optional
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_()):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_size = embed_size
|
||||
self.embed_size_per_partition = embed_size // (self.summa_dim**2)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.cls_token = Parameter(
|
||||
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
self.pos_embed = Parameter(
|
||||
torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition),
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.summa_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.summa_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
fan_out = self.embed_size
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
B, C, H, W = input_.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
input_ = split_batch_2d(input_)
|
||||
|
||||
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
|
||||
if self.flatten:
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
cls_token = all_gather_weight_2d.apply(self.cls_token, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
pos_embed = all_gather_weight_2d.apply(self.pos_embed, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
cls_token = cls_token.expand(output.shape[0], -1, -1)
|
||||
output = torch.cat((cls_token, output), dim=1)
|
||||
output = output + pos_embed
|
||||
|
||||
return output
|
||||
|
||||
|
||||
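PatchEmbedding2D splits the batch, all-gathers its partitioned projection weights along the column group, and then runs an ordinary strided convolution before fusing the cls token and position embedding. A rough single-device sketch of that data flow follows; every size below is made up for illustration.
```
# Single-device sketch of the patch-embedding data flow (illustrative sizes).
import torch
import torch.nn.functional as F

B, C, H, W, P, E = 2, 3, 32, 32, 4, 16
num_patches = (H // P) * (W // P)                 # grid_size[0] * grid_size[1]

x = torch.randn(B, C, H, W)
weight = torch.randn(E, C, P, P)                  # full (gathered) projection weight
bias = torch.zeros(E)
cls_token = torch.zeros(1, 1, E)
pos_embed = torch.zeros(1, num_patches + 1, E)

out = F.conv2d(x, weight, bias, stride=P)         # B x E x H/P x W/P
out = out.flatten(2).transpose(1, 2)              # BCHW -> BNC
out = torch.cat((cls_token.expand(B, -1, -1), out), dim=1)
out = out + pos_embed
print(out.shape)                                  # torch.Size([2, 65, 16])
```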
@LAYERS.register_module
|
||||
class Embedding2D(ParallelLayer):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = divide(embedding_dim, self.summa_dim**2)
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_2d(input_)
|
||||
|
||||
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
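Embedding2D keeps only a slice of the embedding dimension per rank and gathers the full table before the lookup. On one device the lookup itself is just `F.embedding`, as in this illustrative sketch.
```
# One-device sketch of the lookup Embedding2D performs after gathering
# the partitioned weight along the embedding dimension (illustrative sizes).
import torch
import torch.nn.functional as F

num_embeddings, embedding_dim = 100, 16
weight = torch.randn(num_embeddings, embedding_dim)
tokens = torch.randint(0, num_embeddings, (2, 7))
print(F.embedding(tokens, weight).shape)          # torch.Size([2, 7, 16])
```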
@LAYERS.register_module
|
||||
class Classifier2D(ParallelLayer):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
assert_summa_initialization()
|
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(self.in_features, self.summa_dim**2)
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
if self.has_weight:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_COL)[0]
|
||||
row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2D_ROW)[0]
|
||||
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
|
||||
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
out_shape = input_.shape[:-1] + (self.num_classes, )
|
||||
|
||||
return classifier_2d.apply(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank,
|
||||
self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
|
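Classifier2D stores its weight as (num_classes, in_features // q^2) and lets classifier_2d reduce the partial products across the row group. The single-device equivalent is a plain linear head; sizes below are illustrative.
```
# Shape-level sketch of the classifier head on one device: logits = x @ W^T + b.
import torch
import torch.nn.functional as F

in_features, num_classes = 32, 10
x = torch.randn(4, in_features)
W = torch.randn(num_classes, in_features)         # same layout as the partitioned weight
b = torch.zeros(num_classes)
print(F.linear(x, W, b).shape)                    # torch.Size([4, 10])
```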
@@ -1,12 +1,7 @@
from ._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D, Add_Bias_2p5D
from ._transformer import TransformerMLP2p5D, TransformerSelfAttention2p5D, TransformerLayer2p5D
from ._vit import ViTMLP2p5D, ViTSelfAttention2p5D, ViTHead2p5D, ViTPatchEmbedding2p5D, ViTTokenFuser2p5D, ViTInputSplitter2p5D
from .layers import Linear2p5D, LayerNorm2p5D
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D

__all__ = [
    'Matmul_AB_2p5D', 'Matmul_ABT_2p5D', 'Matmul_ATB_2p5D', 'Add_Bias_2p5D',
    'TransformerMLP2p5D', 'TransformerSelfAttention2p5D', 'TransformerLayer2p5D',
    'ViTMLP2p5D', 'ViTSelfAttention2p5D', 'ViTHead2p5D', 'ViTPatchEmbedding2p5D', 'ViTTokenFuser2p5D',
    'ViTInputSplitter2p5D',
    'Linear2p5D', 'LayerNorm2p5D'
    'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
    'Embedding2p5D'
]
@@ -2,11 +2,11 @@ from typing import Any, Tuple

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

@@ -22,25 +22,92 @@ def get_parallel_rank(parallel_mode: ParallelMode):
|
||||
return gpc.get_local_rank(parallel_mode)
|
||||
|
||||
|
||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||
|
||||
|
||||
class classifier_2p5d(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(
|
||||
ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
bias,
|
||||
tesseract_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
) -> Tensor:
|
||||
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
B_temp = all_gather(B, -1, col_parallel_mode)
|
||||
if ctx:
|
||||
ctx.save_for_backward(A, B_temp)
|
||||
|
||||
C = torch.matmul(A, B_temp.transpose(0, 1))
|
||||
|
||||
C = all_reduce(C, row_parallel_mode)
|
||||
|
||||
ctx.use_bias = bias is not None
|
||||
if bias is not None:
|
||||
C = C + bias
|
||||
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.row_rank = row_rank
|
||||
ctx.col_rank = col_rank
|
||||
ctx.row_parallel_mode = row_parallel_mode
|
||||
ctx.col_parallel_mode = col_parallel_mode
|
||||
ctx.A_shape = A_shape
|
||||
ctx.B_shape = B_shape
|
||||
ctx.data_parallel_rank = data_parallel_rank
|
||||
ctx.pipeline_parallel_rank = pipeline_parallel_rank
|
||||
ctx.pipeline_parallel_size = pipeline_parallel_size
|
||||
ctx.tensor_parallel_size = tensor_parallel_size
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
|
||||
with torch.no_grad():
|
||||
A_grad = torch.matmul(output_grad, B)
|
||||
A_grad = A_grad.reshape(ctx.A_shape)
|
||||
B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
|
||||
B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
|
||||
B_grad = B_grad.reshape(ctx.B_shape)
|
||||
|
||||
bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
|
||||
bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
|
||||
|
||||
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
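The backward of classifier_2p5d reassembles three standard gradients with an all-reduce (for the bias) and a reduce-scatter (for the weight). A quick single-device check of those identities against autograd, with made-up sizes:
```
# For C = A @ B^T + bias:
#   dA = dC @ B, dB = dC^T @ A, dbias = dC summed over every non-class dim.
import torch

A = torch.randn(4, 8, requires_grad=True)          # activations
B = torch.randn(10, 8, requires_grad=True)         # weight (num_classes, in_features)
bias = torch.zeros(10, requires_grad=True)

C = A @ B.t() + bias
dC = torch.randn_like(C)
C.backward(dC)

assert torch.allclose(A.grad, dC @ B, atol=1e-5)
assert torch.allclose(B.grad, dC.t() @ A, atol=1e-5)
assert torch.allclose(bias.grad, dC.sum(dim=0), atol=1e-5)
```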
class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
tesseract_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
|
||||
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int) -> Tensor:
|
||||
# A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
|
||||
# B: [h / dq, s / q]
|
||||
@@ -59,8 +126,8 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
|
||||
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
|
||||
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
|
||||
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
|
||||
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
|
||||
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
|
||||
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
|
||||
@@ -100,52 +167,26 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2p5D.apply(
|
||||
output_grad, B,
|
||||
ctx.tesseract_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2p5D.apply(
|
||||
A, output_grad,
|
||||
ctx.tesseract_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
|
||||
B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
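Matmul_AB_2p5D broadcasts the local A blocks along the row group and the local B blocks along the column group, accumulating a blocked product. The identity it relies on, written out on one device with an illustrative block count q:
```
# Blocked product: C = sum_k A[:, k-th column block] @ B[k-th row block, :].
import torch

q = 2                                    # stands in for tesseract_dim
A = torch.randn(6, 4 * q)                # full A, column-blocked into q pieces
B = torch.randn(4 * q, 5)                # full B, row-blocked into q pieces

C = torch.zeros(6, 5)
for k in range(q):
    C += A[:, k * 4:(k + 1) * 4] @ B[k * 4:(k + 1) * 4, :]

assert torch.allclose(C, A @ B, atol=1e-4)
```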
class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB^T`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
tesseract_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int
|
||||
) -> Tensor:
|
||||
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
|
||||
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int) -> Tensor:
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], \
|
||||
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
|
||||
@@ -197,50 +238,25 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_AB_2p5D.apply(
|
||||
output_grad, B,
|
||||
ctx.tesseract_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_ATB_2p5D.apply(
|
||||
output_grad, A,
|
||||
ctx.tesseract_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
|
||||
B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = A^TB`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
tesseract_dim: int,
|
||||
out_shape: Tuple[int, ...],
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
row_parallel_mode: ParallelMode,
|
||||
col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
|
||||
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
|
||||
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int):
|
||||
|
||||
assert A.shape[-2] == B.shape[-2], \
|
||||
@@ -261,14 +277,12 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
src_a = i + row_rank * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(A_temp, src=src_a,
|
||||
group=get_parallel_group(row_parallel_mode))
|
||||
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
|
||||
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
|
||||
src_c = col_rank + i * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(C_temp, dst=src_c,
|
||||
group=get_parallel_group(col_parallel_mode))
|
||||
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
|
||||
if i == row_rank:
|
||||
C = C_temp.clone()
|
||||
|
||||
@@ -295,59 +309,30 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_2p5D.apply(
|
||||
B, output_grad,
|
||||
ctx.tesseract_dim, ctx.A_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
B_grad = Matmul_AB_2p5D.apply(
|
||||
A, output_grad,
|
||||
ctx.tesseract_dim, ctx.B_shape,
|
||||
ctx.row_rank, ctx.col_rank, ctx.dep_rank,
|
||||
ctx.row_parallel_mode,
|
||||
ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank,
|
||||
ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size,
|
||||
ctx.tensor_parallel_size
|
||||
)
|
||||
A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
|
||||
B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
|
||||
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
|
||||
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
|
||||
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
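The three matmul kernels route their backward passes through one another because, for C = AB, the gradients are dA = dC B^T (an AB^T product) and dB = A^T dC (an A^T B product). A one-device autograd check of that routing (sizes are illustrative):
```
import torch

A = torch.randn(3, 4, requires_grad=True)
B = torch.randn(4, 5, requires_grad=True)
C = A @ B
dC = torch.randn_like(C)
C.backward(dC)

assert torch.allclose(A.grad, dC @ B.t(), atol=1e-5)   # the Matmul_ABT pattern
assert torch.allclose(B.grad, A.t() @ dC, atol=1e-5)   # the Matmul_ATB pattern
```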
class Add_Bias_2p5D(torch.autograd.Function):
|
||||
"""Matrix add bias: :math:`C = A + b`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
bias: Tensor,
|
||||
output_size_per_partition: int,
|
||||
tesseract_dim: int,
|
||||
row_rank: int,
|
||||
col_rank: int,
|
||||
dep_rank: int,
|
||||
col_parallel_mode: ParallelMode,
|
||||
skip_bias_add: bool,
|
||||
data_parallel_rank: int,
|
||||
pipeline_parallel_rank: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int
|
||||
) -> Tensor:
|
||||
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
|
||||
row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool,
|
||||
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int) -> Tensor:
|
||||
if row_rank == 0:
|
||||
bias_temp = bias.clone()
|
||||
else:
|
||||
bias_temp = torch.zeros(
|
||||
output_size_per_partition,
|
||||
dtype=bias.dtype,
|
||||
device=get_current_device())
|
||||
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
||||
src_rank = col_rank + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
@@ -407,14 +392,10 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class _LayerNorm_2p5D(torch.autograd.Function):
|
||||
class layernorm_2p5d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx: Any,
|
||||
input: Tensor,
|
||||
E_x: Tensor,
|
||||
Var_x: Tensor,
|
||||
hidden_size: int,
|
||||
def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
|
||||
row_parallel_mode: ParallelMode) -> Tensor:
|
||||
input = input - E_x
|
||||
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
|
||||
@@ -432,14 +413,11 @@ class _LayerNorm_2p5D(torch.autograd.Function):
|
||||
# in here, Var_x = 1 / sqrt(Var[x] + eps), x = (x - E[x]) * Var_x
|
||||
with torch.no_grad():
|
||||
output_grad_sum = torch.sum(output_grad, dim=-1, keepdim=True)
|
||||
torch.distributed.all_reduce(
|
||||
output_grad_sum, group=get_parallel_group(row_parallel_mode))
|
||||
torch.distributed.all_reduce(output_grad_sum, group=get_parallel_group(row_parallel_mode))
|
||||
output_grad_sum /= ctx.hidden_size
|
||||
|
||||
output_grad_mul_x_sum = torch.sum(
|
||||
output_grad * x, dim=-1, keepdim=True)
|
||||
torch.distributed.all_reduce(
|
||||
output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode))
|
||||
output_grad_mul_x_sum = torch.sum(output_grad * x, dim=-1, keepdim=True)
|
||||
torch.distributed.all_reduce(output_grad_mul_x_sum, group=get_parallel_group(row_parallel_mode))
|
||||
output_grad_mul_x_sum /= ctx.hidden_size
|
||||
|
||||
input_grad = output_grad.clone()
|
||||
@@ -450,105 +428,28 @@ class _LayerNorm_2p5D(torch.autograd.Function):
|
||||
return input_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
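The layer-norm backward above only needs two row-reduced sums, mean(g) and mean(g * xhat). For reference, the closed form they feed into is the standard layer-norm input gradient; a single-device check against autograd, with illustrative sizes:
```
import torch

H, eps = 8, 1e-05
x = torch.randn(2, 3, H, requires_grad=True)
y = torch.nn.functional.layer_norm(x, (H,), eps=eps)
g = torch.randn_like(y)
y.backward(g)

xd = x.detach()
mu = xd.mean(-1, keepdim=True)
inv_std = 1.0 / torch.sqrt(xd.var(-1, unbiased=False, keepdim=True) + eps)
xhat = (xd - mu) * inv_std
# dL/dx = inv_std * (g - mean(g) - xhat * mean(g * xhat))
dx = inv_std * (g - g.mean(-1, keepdim=True) - xhat * (g * xhat).mean(-1, keepdim=True))
assert torch.allclose(dx, x.grad, atol=1e-4)
```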
# class Sum_2p5D(torch.autograd.Function):
|
||||
# """Compute the sum of input tensors
|
||||
# """
|
||||
|
||||
# @staticmethod
|
||||
# def forward(ctx,
|
||||
# inputs,
|
||||
# dim,
|
||||
# tesseract_dim,
|
||||
# row_parallel_mode,
|
||||
# keepdim=False):
|
||||
# # input: [b/q, s, h/q]
|
||||
# ctx.save_for_backward(inputs)
|
||||
# # sum: [b/q, s]
|
||||
# out = torch.sum(inputs, dim=dim, keepdim=keepdim)
|
||||
# torch.distributed.all_reduce(
|
||||
# out, group=gpc.get_group(row_parallel_mode))
|
||||
# return out
|
||||
|
||||
# @staticmethod
|
||||
# def backward(ctx, output_grad):
|
||||
# with torch.no_grad():
|
||||
# inputs = ctx.saved_tensors
|
||||
# input_grad = torch.ones(inputs.shape, dtype=output_grad.dtype)
|
||||
# return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
# class _ViT_Split_2p5D(torch.autograd.Function):
|
||||
# @staticmethod
|
||||
# @custom_fwd(cast_inputs=torch.float16)
|
||||
# def forward(ctx, inputs, batch_size,
|
||||
# tesseract_dim, tesseract_dep,
|
||||
# xz_parallel_mode):
|
||||
# # inputs: [b, s, h/q]
|
||||
# # output: [b/dq, s, h/q]
|
||||
|
||||
# ctx.BATCH_SIZE = batch_size
|
||||
# ctx.tesseract_dim = tesseract_dim
|
||||
# ctx.tesseract_dep = tesseract_dep
|
||||
# ctx.xz_parallel_mode = xz_parallel_mode
|
||||
# xz_rank = gpc.get_local_rank(xz_parallel_mode)
|
||||
# output = torch.chunk(inputs, tesseract_dep *
|
||||
# tesseract_dim, dim=0)[xz_rank]
|
||||
# output = output.clone()
|
||||
# return output
|
||||
|
||||
# @staticmethod
|
||||
# @custom_bwd
|
||||
# def backward(ctx, output_grad):
|
||||
# # output_grad: [b/dq, s, h/q]
|
||||
# # grads: [b, s, h/q]
|
||||
# # *
|
||||
# grads_shape = (ctx.BATCH_SIZE,) + output_grad.shape[1:]
|
||||
# grads = torch.empty(grads_shape,
|
||||
# dtype=output_grad.dtype,
|
||||
# device=get_current_device())
|
||||
# dist.all_gather(list(grads.chunk(ctx.tesseract_dim * ctx.tesseract_dep, dim=0)),
|
||||
# output_grad.contiguous(),
|
||||
# group=get_parallel_group(ctx.xz_parallel_mode))
|
||||
# return grads, None, None, None, None
|
||||
|
||||
class AllGatherLast(torch.autograd.Function):
|
||||
|
||||
class all_gather_weight_2p5d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
tesseract_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.dim = dim
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
|
||||
|
||||
last_dim = tesseract_dim * inputs.size(-1)
|
||||
outputs_shape = (last_dim,) + inputs.shape[:-1]
|
||||
outputs = torch.empty(
|
||||
outputs_shape, dtype=inputs.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(outputs.chunk(tesseract_dim, dim=0)),
|
||||
inputs.permute(2, 0, 1).contiguous(),
|
||||
group=gpc.get_group(col_parallel_mode)
|
||||
)
|
||||
outputs = outputs.permute(1, 2, 0).contiguous()
|
||||
outputs = all_gather(inputs, dim, col_parallel_mode)
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad = output_grad.chunk(ctx.tesseract_dim, dim=-1)[ctx.row_rank]
|
||||
return grad.contiguous(), None, None
|
||||
grad = output_grad.chunk(ctx.tesseract_dim, dim=ctx.dim)[ctx.row_rank]
|
||||
return grad.contiguous(), None, None, None
|
||||
|
||||
|
||||
class SplitFirst(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
inputs: Tensor,
|
||||
tesseract_dim: int,
|
||||
col_parallel_mode: ParallelMode) -> Tensor:
|
||||
def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.tesseract_dim = tesseract_dim
|
||||
ctx.batch_size = inputs.size(0)
|
||||
ctx.para_mode = col_parallel_mode
|
||||
@@ -560,12 +461,33 @@ class SplitFirst(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grad = torch.empty(
|
||||
grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
dist.all_gather(
|
||||
list(grad.chunk(ctx.tesseract_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.para_mode)
|
||||
)
|
||||
return grad, None, None
|
||||
grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
|
||||
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
|
||||
output_grad.contiguous(),
|
||||
group=gpc.get_group(ctx.para_mode))
|
||||
return grad, None, None
|
||||
|
||||
|
||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
    return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
                       dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()


class reduce_by_batch_2p5d(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""
    @staticmethod
    def symbolic(graph, input_):
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
        return input_

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input_):
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
        return input_.clone()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        return grad_output
|
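split_batch_2p5d simply keeps each rank's contiguous chunk of the batch along the column group, and reduce_by_batch_2p5d all-reduces results back across that group. A one-process illustration of the chunking (group size and rank are made up):
```
import torch

world_size, rank = 4, 1                   # illustrative column-group size / local rank
batch = torch.arange(8)
shard = torch.chunk(batch, world_size, dim=0)[rank].contiguous()
print(shard)                              # tensor([2, 3])
```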
@@ -1,220 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor
|
||||
|
||||
from colossalai.nn.layer._common_utils import divide
|
||||
from colossalai.registry import LAYERS
|
||||
from ._utils import assert_tesseract_initialization, \
|
||||
get_tesseract_dim_dep_from_env
|
||||
from .layers import Linear2p5D, LayerNorm2p5D
|
||||
from .._common_utils import ACT2FN
|
||||
from ..base_layer import ParallelLayer
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerMLP2p5D(ParallelLayer):
|
||||
"""
|
||||
The MLP takes an input with hidden size h, projects it to mlp_ratio * h,
applies a nonlinear transformation, and projects back to h. Dropout is
applied at the end.
|
||||
|
||||
:param in_features: the size of input tensor
|
||||
:type in_features: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
|
||||
:type mlp_ratio: int, optional
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param dropout_prob: dropout probability, defaults to 0.
|
||||
:type dropout_prob: float, optional
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int = 4.0,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
skip_bias_add: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.in_features = in_features
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
# Project to h * mlp_ratio.
|
||||
self.dense_1 = Linear2p5D(
|
||||
in_features,
|
||||
int(mlp_ratio * in_features),
|
||||
dtype=dtype,
|
||||
skip_bias_add=skip_bias_add
|
||||
)
|
||||
|
||||
assert act_func in ACT2FN.keys(), f'Invalid value for argument act_func, ' \
|
||||
f'activation function can only be {list(ACT2FN.keys())}'
|
||||
self.activation_func = ACT2FN[act_func]
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear2p5D(
|
||||
int(mlp_ratio * in_features),
|
||||
in_features,
|
||||
dtype=dtype,
|
||||
skip_bias_add=skip_bias_add
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
self.layernorm = LayerNorm2p5D(in_features, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if self.skip_bias_add:
|
||||
intermediate_output, _ = self.dense_1(x)
|
||||
else:
|
||||
intermediate_output = self.dense_1(x)
|
||||
|
||||
intermediate_output = self.activation_func(intermediate_output)
|
||||
|
||||
if self.skip_bias_add:
|
||||
output, _ = self.dense_2(intermediate_output)
|
||||
else:
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
output = self.dropout(output)
|
||||
output = self.layernorm(x + output)
|
||||
return output
|
||||
|
||||
|
||||
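The MLP block being removed here follows the usual Transformer recipe: project to mlp_ratio * h, apply the activation, project back, dropout, then a residual connection and layer norm. A compact single-device sketch with illustrative sizes:
```
import torch
from torch import nn
import torch.nn.functional as F

h, mlp_ratio = 16, 4
dense_1, dense_2 = nn.Linear(h, mlp_ratio * h), nn.Linear(mlp_ratio * h, h)
dropout, norm = nn.Dropout(0.1), nn.LayerNorm(h)

x = torch.randn(2, 5, h)
out = norm(x + dropout(dense_2(F.gelu(dense_1(x)))))
print(out.shape)                          # torch.Size([2, 5, 16])
```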
@LAYERS.register_module
|
||||
class TransformerSelfAttention2p5D(ParallelLayer):
|
||||
"""Self attention layer for 2.5D parallel Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layer
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layer
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = divide(
|
||||
num_attention_heads, self.tesseract_dim) # *
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.query_key_value = Linear2p5D(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear2p5D(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.layernorm = LayerNorm2p5D(
|
||||
hidden_size,
|
||||
dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute((0, 2, 1, 3)).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
output = self.dropout(output)
|
||||
attention_output = self.layernorm(hidden_states + output)
|
||||
|
||||
return attention_output
|
||||
|
||||
|
||||
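The attention forward above is the standard scaled dot-product pattern: scores = QK^T / sqrt(d_head), add the additive mask, softmax, then multiply by V. A single-head, single-device sketch (sizes are illustrative):
```
import math
import torch

B, S, d_head = 2, 5, 16
q, k, v = (torch.randn(B, S, d_head) for _ in range(3))
mask = torch.zeros(B, 1, S)               # additive attention mask (0 = keep)

scores = q @ k.transpose(-1, -2) / math.sqrt(d_head) + mask
probs = torch.softmax(scores, dim=-1)
context = probs @ v
print(context.shape)                      # torch.Size([2, 5, 16])
```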
@LAYERS.register_module
|
||||
class TransformerLayer2p5D(ParallelLayer):
|
||||
"""Transformer layer which contains a self-attention layer and a MLP layer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.0
|
||||
:type mlp_ratio: float, optional
|
||||
:param attention_dropout_prob: dropout probability for attention layer, defaults to 0.
|
||||
:type attention_dropout_prob: float, optional
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers, defaults to 0.
|
||||
:type hidden_dropout_prob: float, optional
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention = TransformerSelfAttention2p5D(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.mlp = TransformerMLP2p5D(
|
||||
in_features=hidden_size,
|
||||
dropout_prob=hidden_dropout_prob,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
attention_output = self.attention(hidden_states, attention_mask)
|
||||
output = self.mlp(attention_output)
|
||||
return output
|
@@ -1,421 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn as nn, Tensor, distributed as dist
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from colossalai.context import seed, ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import checkpoint
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import AllGatherLast, SplitFirst
|
||||
from ._utils import assert_tesseract_initialization, \
|
||||
get_tesseract_dim_dep_from_env
|
||||
from .layers import Linear2p5D
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..fused_bias_gelu import bias_gelu_impl
|
||||
from .._common_utils import (ACT2FN, divide, to_2tuple,
|
||||
set_tensor_parallel_attribute_by_partition)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP2p5D(ParallelLayer):
|
||||
"""MLP layer for 2.5D parallel Vision Transformer
|
||||
|
||||
:param in_features: size of each input sample
|
||||
:type in_features: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim
|
||||
:type mlp_ratio: int
|
||||
:param act_func: activation function, defaults to 'gelu'
|
||||
:type act_func: str, optional
|
||||
:param dropout_prob: dropout probability, defaults to 0.
|
||||
:type dropout_prob: float, optional
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.in_features = in_features
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
|
||||
if act_func == 'fused_gelu':
|
||||
self.act = bias_gelu_impl
|
||||
skip_dense_1_add_bias = True
|
||||
else:
|
||||
self.act = ACT2FN[act_func]
|
||||
skip_dense_1_add_bias = False
|
||||
|
||||
# Project to mlp_ratio * h.
|
||||
self.dense_1 = Linear2p5D(
|
||||
self.in_features,
|
||||
self.mlp_ratio * self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=weight_init,
|
||||
skip_bias_add=skip_dense_1_add_bias
|
||||
)
|
||||
|
||||
self.act = ACT2FN[act_func]
|
||||
|
||||
# Project back to h.
|
||||
self.dense_2 = Linear2p5D(
|
||||
self.mlp_ratio * self.in_features,
|
||||
self.in_features,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=weight_init
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout_prob)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.act == bias_gelu_impl:
|
||||
intermediate_output, bias = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output, bias)
|
||||
else:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.act(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
intermediate_output = self.dropout(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTSelfAttention2p5D(ParallelLayer):
|
||||
"""Self-attention layer for 2.5D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout_prob: dropout probability for attention layers
|
||||
:type attention_dropout_prob: float
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers
|
||||
:type hidden_dropout_prob: float
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
:param checkpoint: If set to `True`, activation checkpoint is used, defaults to `False`
|
||||
:type checkpoint: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = divide(
|
||||
num_attention_heads, self.tesseract_dim) # *
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.checkpoint = checkpoint
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_bias = weight_init
|
||||
|
||||
self.query_key_value = Linear2p5D(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=self.init_bias
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(attention_dropout_prob)
|
||||
self.dense = Linear2p5D(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
dtype=dtype,
|
||||
init_weight=weight_init,
|
||||
init_bias=self.init_bias
|
||||
)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(
|
||||
query_key_value, 3, dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(
|
||||
query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / \
|
||||
math.sqrt(self.attention_head_size)
|
||||
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[
|
||||
:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead2p5D(ParallelLayer):
|
||||
"""Output layer for 2.5D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=None,
|
||||
weight_init='torch'
|
||||
):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
assert weight_init in ('torch', 'jax')
|
||||
if weight_init == 'jax':
|
||||
self.init_weight = 'zero'
|
||||
self.init_bias = 'zero'
|
||||
else:
|
||||
self.init_weight = weight_init
|
||||
self.init_bias = weight_init
|
||||
|
||||
self.linear = Linear2p5D(
|
||||
hidden_size,
|
||||
num_classes,
|
||||
dtype=dtype,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTPatchEmbedding2p5D(ParallelLayer):
|
||||
""" 2.5D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param in_chans: number of channels of input image, defaults to 3
|
||||
:type in_chans: int, optional
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
in_chans=3,
|
||||
flatten=True,
|
||||
weight_init='torch'):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_dim = embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2) # *
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.proj = nn.Conv2d(in_chans,
|
||||
self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
device=get_current_device()
|
||||
)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
if weight_init == 'jax':
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
std = math.sqrt(1.0 / fan_in)
|
||||
nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
nn.init.zeros_(self.proj.bias)
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.weight, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.proj.bias, num_partition)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTInputSplitter2p5D(ParallelLayer):
|
||||
"""Split the input tensor for 2D parallel Vision Transformer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = AllGatherLast.apply(
|
||||
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
x = SplitFirst.apply(
|
||||
x, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTTokenFuser2p5D(ParallelLayer):
|
||||
"""
|
||||
Fuse the cls token and position embedding into the input
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param drop_rate: dropout probability, defaults to 0.
|
||||
:type drop_rate: float, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size,
|
||||
patch_size,
|
||||
embed_dim,
|
||||
drop_rate=0.
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(
|
||||
(1, 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
self.pos_embed = nn.Parameter(torch.empty(
|
||||
(1, self.num_patches + 1, self.embed_dim // (self.tesseract_dep * self.tesseract_dim ** 2)),
|
||||
device=get_current_device()))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, num_partition)
|
||||
|
||||
def _broadcast_params(self, param) -> None:
|
||||
" broadcast to all column ranks for data consistency "
|
||||
if self.tesseract_dep > 1:
|
||||
xz_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
xz_group = gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
dist.broadcast(param, src=xz_rank[0],
|
||||
group=xz_group)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2P5D_XZ))
|
||||
grad = grad / self.tesseract_dim # / self.tesseract_dep # *
|
||||
return grad
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_token = AllGatherLast.apply(
|
||||
self.cls_token, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
cls_token = cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
pos_embed = AllGatherLast.apply(
|
||||
self.pos_embed, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
x = x + pos_embed
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.pos_drop(x)
|
||||
return x
|
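
# A minimal shape sketch, assuming img_size=224, patch_size=16, embed_dim=768 and a tesseract
# layout of tesseract_dim=2, tesseract_dep=1. The helper below is hypothetical (not part of this
# file); it only shows how ViTTokenFuser2p5D sizes its partitioned cls token and position embedding.
def _token_fuser_param_shapes(img_size=224, patch_size=16, embed_dim=768,
                              tesseract_dim=2, tesseract_dep=1):
    grid = img_size // patch_size
    num_patches = grid * grid                                  # 14 * 14 = 196
    partitioned_dim = embed_dim // (tesseract_dep * tesseract_dim ** 2)
    cls_token_shape = (1, 1, partitioned_dim)                  # matches self.cls_token above
    pos_embed_shape = (1, num_patches + 1, partitioned_dim)    # matches self.pos_embed above
    return cls_token_shape, pos_embed_shape

# forward() gathers these partitions over the column group, prepends the cls token and adds the
# position embedding, so the sequence length grows from num_patches to num_patches + 1.
assert _token_fuser_param_shapes() == ((1, 1, 192), (1, 197, 192))
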
@@ -1,17 +1,23 @@
import math
from typing import Callable

import torch
from torch import Tensor
from torch.nn import Parameter, init as init

from colossalai.context import seed, ParallelMode
import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from ._operation import Matmul_AB_2p5D, Add_Bias_2p5D, _LayerNorm_2p5D
from ._utils import get_tesseract_dim_dep_from_env, assert_tesseract_initialization
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from torch import Tensor, dtype
from torch.nn import Parameter

from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..base_layer import ParallelLayer
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
                         split_batch_2p5d)
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)


@LAYERS.register_module
@@ -27,16 +33,14 @@ class Linear2p5D(ParallelLayer):
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype=None,
|
||||
dtype: dtype = None,
|
||||
skip_bias_add: bool = False,
|
||||
init_weight='torch',
|
||||
init_bias='torch'
|
||||
):
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
@@ -52,76 +56,48 @@ class Linear2p5D(ParallelLayer):
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
|
||||
self.hidden_size_per_partition = divide(
|
||||
out_features, self.tesseract_dim)
|
||||
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.input_size_per_partition,
|
||||
self.hidden_size_per_partition,
|
||||
**factory_kwargs))
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs))
|
||||
|
||||
# create bias, shape: [h/q]
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.hidden_size_per_partition,
|
||||
**factory_kwargs))
|
||||
self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
# initialize parameters
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
assert init_weight in ('torch', 'jax', 'zero')
|
||||
assert init_bias in ('torch', 'jax', 'zero')
|
||||
# setting
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
|
||||
# init weight
|
||||
if init_weight == 'torch':
|
||||
a = math.sqrt(5)
|
||||
nonlinearity = 'leaky_relu'
|
||||
std = init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
|
||||
bound = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -bound, bound)
|
||||
elif init_weight == 'jax':
|
||||
std = math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
a = math.sqrt(3.0) * std
|
||||
init.uniform_(self.weight, -a, a)
|
||||
elif init_weight == 'zero':
|
||||
init.zeros_(self.weight)
|
||||
|
||||
# init bias
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
if init_bias == 'torch':
|
||||
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
||||
init.uniform_(self.bias, -bound, bound)
|
||||
elif init_bias == 'jax':
|
||||
init.normal_(self.bias, std=1e-6)
|
||||
elif init_bias == 'zero':
|
||||
init.zeros_(self.bias)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/dq, n/q, k/q]
|
||||
# output: [m/dq, n/q, h/q]
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition,)
|
||||
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
|
||||
|
||||
output = Matmul_AB_2p5D.apply(
|
||||
x,
|
||||
self.weight,
|
||||
self.tesseract_dim,
|
||||
out_shape,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
self.row_rank,
|
||||
self.col_rank,
|
||||
self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
self.data_parallel_rank,
|
||||
@@ -132,34 +108,17 @@ class Linear2p5D(ParallelLayer):
|
||||
|
||||
if self.bias is not None:
|
||||
if self.skip_bias_add:
|
||||
bias = Add_Bias_2p5D.apply(
|
||||
None,
|
||||
self.bias,
|
||||
self.hidden_size_per_partition,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
bias = Add_Bias_2p5D.apply(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL,
|
||||
True, self.data_parallel_rank, self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size, self.tensor_parallel_size)
|
||||
return output, bias
|
||||
else:
|
||||
output = Add_Bias_2p5D.apply(
|
||||
output,
|
||||
self.bias,
|
||||
self.hidden_size_per_partition,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
False,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
output = Add_Bias_2p5D.apply(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_COL, False, self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
return output
|
||||
else:
|
||||
return output
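
# A small illustration of how Linear2p5D shards its parameters, using a hypothetical helper and
# example sizes (in_features=1024, out_features=4096, tesseract_dim=2) that are not part of this
# file: each rank holds an [in/q, out/q] weight block and an [out/q] bias slice.
def _linear2p5d_partition(in_features=1024, out_features=4096, tesseract_dim=2):
    weight_shape = (in_features // tesseract_dim, out_features // tesseract_dim)
    bias_shape = (out_features // tesseract_dim,)
    return weight_shape, bias_shape

# Each of the q * q ranks in a tesseract plane stores 1/q^2 of the weight matrix, so weight memory
# per device shrinks quadratically with tesseract_dim while Matmul_AB_2p5D in forward() above
# reassembles the full product.
assert _linear2p5d_partition() == ((512, 2048), (2048,))
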
|
||||
@@ -179,12 +138,7 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
:param dtype: The dtype of parameters, defaults to None
|
||||
:type dtype: torch.dtype, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
normalized_shape: int,
|
||||
eps: float = 1e-05,
|
||||
dtype=None
|
||||
):
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-05, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
# layer norm config
|
||||
@@ -199,66 +153,251 @@ class LayerNorm2p5D(ParallelLayer):
|
||||
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.partitioned_partition = divide(
|
||||
normalized_shape, self.tesseract_dim) # *
|
||||
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
|
||||
|
||||
# create parameters
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
|
||||
self.gamma = Parameter(torch.ones(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(
|
||||
self.partitioned_partition,
|
||||
**factory_kwargs))
|
||||
self.gamma = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
self.beta = Parameter(torch.zeros(self.partitioned_partition, **factory_kwargs))
|
||||
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, num_partition)
|
||||
set_tensor_parallel_attribute_by_partition(self.gamma, self.tesseract_dim)
|
||||
set_tensor_parallel_attribute_by_partition(self.beta, self.tesseract_dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
with torch.no_grad():
|
||||
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
|
||||
torch.distributed.all_reduce(
|
||||
E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
E_x /= self.normalized_shape
|
||||
|
||||
# Var_x in the block below is the sum of input^2
|
||||
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
|
||||
torch.distributed.all_reduce(
|
||||
Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
Var_x /= self.normalized_shape
|
||||
|
||||
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
|
||||
# this time 1/sqrt(Var_x + epsilon)
|
||||
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
|
||||
|
||||
output = _LayerNorm_2p5D.apply(x, E_x, Var_x, self.normalized_shape,
|
||||
ParallelMode.PARALLEL_2P5D_ROW)
|
||||
bias = Add_Bias_2p5D.apply(
|
||||
None, self.beta, self.partitioned_partition,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
scale = Add_Bias_2p5D.apply(
|
||||
None, self.gamma, self.partitioned_partition,
|
||||
self.tesseract_dim,
|
||||
self.row_rank, self.col_rank, self.dep_rank,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
True,
|
||||
self.data_parallel_rank,
|
||||
self.pipeline_parallel_rank,
|
||||
self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size
|
||||
)
|
||||
output = layernorm_2p5d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
|
||||
bias = Add_Bias_2p5D.apply(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
scale = Add_Bias_2p5D.apply(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank,
|
||||
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
|
||||
output = torch.addcmul(bias, scale, output)
|
||||
return output
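
# Single-process check (illustrative sketch, not part of the patch) of the statistics computed in
# LayerNorm2p5D.forward above: each rank holds a slice of the hidden dimension, contributes partial
# sums of x and x*x, and the row-wise all-reduce turns those partial sums into the global mean and
# variance of the full hidden dimension.
import torch

def _partitioned_layernorm_stats(x: torch.Tensor, q: int, eps: float = 1e-5):
    hidden = x.shape[-1]
    partial_sums = [chunk.sum(dim=-1, keepdim=True) for chunk in x.chunk(q, dim=-1)]
    partial_sq_sums = [(chunk * chunk).sum(dim=-1, keepdim=True) for chunk in x.chunk(q, dim=-1)]
    mean = sum(partial_sums) / hidden                  # what the all-reduce of E_x produces
    var = sum(partial_sq_sums) / hidden - mean * mean  # Var_x = E[x^2] - E[x]^2
    return mean, 1.0 / torch.sqrt(var + eps)

_x = torch.randn(4, 8, 64)
_mean, _inv_std = _partitioned_layernorm_stats(_x, q=2)
assert torch.allclose(_mean, _x.mean(dim=-1, keepdim=True), atol=1e-5)
assert torch.allclose(_inv_std,
                      1.0 / torch.sqrt(_x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5),
                      atol=1e-4)
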
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class PatchEmbedding2p5D(ParallelLayer):
|
||||
""" 2D Image to Patch Embedding
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param embed_dim: dimension of embedding
|
||||
:type embed_dim: int
|
||||
:param in_chans: number of channels of input image, defaults to 3
|
||||
:type in_chans: int, optional
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_()):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_size = embed_size
|
||||
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.cls_token = Parameter(
|
||||
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
self.pos_embed = Parameter(
|
||||
torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition),
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
fan_out = self.embed_size
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
B, C, H, W = input_.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
input_ = split_batch_2p5d(input_)
|
||||
|
||||
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
|
||||
if self.flatten:
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
cls_token = all_gather_weight_2p5d.apply(self.cls_token, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
pos_embed = all_gather_weight_2p5d.apply(self.pos_embed, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
cls_token = cls_token.expand(output.shape[0], -1, -1)
|
||||
output = torch.cat((cls_token, output), dim=1)
|
||||
output = output + pos_embed
|
||||
|
||||
return output
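
# Single-process emulation (illustrative only, with made-up sizes) of the gather-then-convolve
# pattern used in PatchEmbedding2p5D.forward above: each column rank holds a slice of the output
# channels of the projection weight, and concatenating the slices before F.conv2d reproduces the
# full-channel patch projection.
import torch
import torch.nn.functional as F

_full_weight = torch.randn(8, 3, 4, 4)                 # [embed, in_chans, patch, patch]
_shards = list(_full_weight.chunk(2, dim=0))           # what two column ranks would hold
_gathered = torch.cat(_shards, dim=0)                  # effect of all_gather_weight_2p5d along dim 0
_image = torch.randn(1, 3, 8, 8)
assert torch.equal(F.conv2d(_image, _gathered, stride=4),
                   F.conv2d(_image, _full_weight, stride=4))
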
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Embedding2p5D(ParallelLayer):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
assert_tesseract_initialization()
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_2p5d(input_)
|
||||
|
||||
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
return output
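
# Single-process emulation (illustrative only) of Embedding2p5D.forward above: the table is sharded
# along the embedding dimension, and concatenating the shards before F.embedding gives the same
# lookup as an unsharded table. The tensors below are made-up examples, not library state.
import torch
import torch.nn.functional as F

_table = torch.randn(10, 8)
_shards = list(_table.chunk(2, dim=-1))                # per-rank columns of the embedding table
_gathered = torch.cat(_shards, dim=-1)                 # effect of all_gather_weight_2p5d along dim -1
_ids = torch.tensor([[1, 4, 7]])
assert torch.equal(F.embedding(_ids, _gathered), F.embedding(_ids, _table))
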
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Classifier2p5D(ParallelLayer):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
assert_tesseract_initialization()
|
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
if self.has_weight:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
col_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_COL)[0]
|
||||
row_src_rank = gpc.get_ranks_in_group(ParallelMode.PARALLEL_2P5D_ROW)[0]
|
||||
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
|
||||
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
out_shape = input_.shape[:-1] + (self.num_classes, )
|
||||
|
||||
return classifier_2p5d.apply(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
|
||||
self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
|
||||
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
|
||||
self.tensor_parallel_size)
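
# Brief sketch with hypothetical sizes (not part of this file): Classifier2p5D shards the classifier
# weight along the feature dimension only, so every rank keeps rows for all num_classes outputs.
def _classifier2p5d_weight_shape(in_features=768, num_classes=1000,
                                 tesseract_dim=2, tesseract_dep=1):
    return (num_classes, in_features // (tesseract_dep * tesseract_dim ** 2))

assert _classifier2p5d_weight_shape() == (1000, 192)
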
|
||||
|
@@ -1,9 +1,6 @@
from ._operation import Matmul_ABT_3D, Matmul_ATB_3D, Matmul_AB_3D, Mul_3D, Sum_3D, Add_3D, Reduce_3D
from ._vit import ViTHead3D, ViTMLP3D, ViTPatchEmbedding3D, ViTSelfAttention3D
from .layers import Linear3D, LayerNorm3D
from ._operation import reduce_by_batch_3d, split_batch_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D

__all__ = [
    'Matmul_ABT_3D', 'Matmul_ATB_3D', 'Matmul_AB_3D', 'Mul_3D', 'Sum_3D', 'Add_3D', 'Reduce_3D',
    'ViTHead3D', 'ViTMLP3D', 'ViTPatchEmbedding3D', 'ViTSelfAttention3D',
    'Linear3D', 'LayerNorm3D'
    'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
]
@@ -1,11 +1,10 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.communication import all_gather, all_reduce, reduce_scatter, broadcast, reduce
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from torch import Tensor
|
||||
@@ -15,7 +14,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
class linear_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
def forward(ctx,
|
||||
input_: Tensor,
|
||||
weight: Tensor,
|
||||
bias: Optional[Tensor],
|
||||
@@ -25,33 +24,16 @@ class linear_3d(torch.autograd.Function):
|
||||
input_dim: int = 0,
|
||||
weight_dim: int = -1,
|
||||
output_dim: int = 0) -> Tensor:
|
||||
assert input_.shape[-1] == weight.shape[0], \
|
||||
'Invalid shapes: input = {}, weight = {}.'.format(input_.shape, weight.shape)
|
||||
|
||||
ctx.use_bias = bias is not None
|
||||
|
||||
input_ = all_gather(input_, input_dim, input_parallel_mode)
|
||||
input_ = torch.cat(input_, dim=input_dim)
|
||||
# weight = all_gather(weight, weight_dim, weight_parallel_mode)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight)
|
||||
output = reduce_scatter(output, output_dim, output_parallel_mode)
|
||||
|
||||
if bias is not None:
|
||||
# ranks_in_group = gpc.get_ranks_in_group(output_parallel_mode)
|
||||
# src_rank = ranks_in_group[gpc.get_local_rank(input_parallel_mode)]
|
||||
# dist.broadcast(bias,
|
||||
# src=src_rank,
|
||||
# group=gpc.get_group(output_parallel_mode))
|
||||
# bias = all_gather(bias, -1, weight_parallel_mode)
|
||||
output += bias
|
||||
# ctx.src_rank = src_rank
|
||||
|
||||
# ctx.save_for_backward(input_, weight)
|
||||
# output = torch.matmul(input_, weight)
|
||||
# dist.all_reduce(output, group=gpc.get_group(output_parallel_mode))
|
||||
# output += bias
|
||||
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
@@ -63,115 +45,105 @@ class linear_3d(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
input_, weight = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
# input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
||||
# dist.all_reduce(input_grad,
|
||||
# group=gpc.get_group(ctx.input_parallel_mode))
|
||||
# weight_grad = torch.matmul(
|
||||
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
||||
# output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
# dist.all_reduce(weight_grad,
|
||||
# group=gpc.get_group(ctx.weight_parallel_mode))
|
||||
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)
|
||||
|
||||
# bias_grad = torch.sum(output_grad,
|
||||
# dim=tuple(
|
||||
# range(len(output_grad.shape))[:-1]))
|
||||
# bias_grad = reduce_scatter(bias_grad, -1,
|
||||
# ctx.weight_parallel_mode)
|
||||
# dist.reduce(bias_grad,
|
||||
# dst=ctx.src_rank,
|
||||
# group=gpc.get_group(ctx.output_parallel_mode))
|
||||
# if gpc.get_local_rank(
|
||||
# ctx.output_parallel_mode) != gpc.get_local_rank(
|
||||
# ctx.input_parallel_mode):
|
||||
# bias_grad = None
|
||||
|
||||
# input_ = all_gather(input_, ctx.input_dim, ctx.input_parallel_mode)
|
||||
# weight = all_gather(weight, ctx.weight_dim,
|
||||
# ctx.weight_parallel_mode)
|
||||
|
||||
output_grad = all_gather(output_grad, ctx.output_dim,
|
||||
ctx.output_parallel_mode)
|
||||
output_grad = torch.cat(output_grad, dim=ctx.output_dim)
|
||||
async_ops = list()
|
||||
|
||||
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
|
||||
|
||||
input_grad, input_op = reduce_scatter(input_grad, ctx.input_dim,
|
||||
ctx.input_parallel_mode,
|
||||
async_op=True)
|
||||
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
|
||||
async_ops.append(op)
|
||||
|
||||
weight_grad = torch.matmul(
|
||||
input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
||||
output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
|
||||
async_ops.append(op)
|
||||
|
||||
# weight_grad = torch.matmul(
|
||||
# input_.reshape(-1, input_.shape[-1]).transpose(0, 1),
|
||||
# output_grad.reshape(-1, output_grad.shape[-1]))
|
||||
# weight_grad = reduce_scatter(weight_grad, ctx.weight_dim,
|
||||
# ctx.weight_parallel_mode)
|
||||
if ctx.use_bias:
|
||||
bias_grad = torch.sum(output_grad,
|
||||
dim=tuple(
|
||||
range(len(output_grad.shape))[:-1]))
|
||||
# bias_grad =all_reduce(bias_grad, ctx.output_parallel_mode)
|
||||
# dist.all_reduce(bias_grad,
|
||||
# group=gpc.get_group(ctx.weight_parallel_mode))
|
||||
weight_grad = torch.cat([weight_grad, torch.unsqueeze(bias_grad, dim=0)])
|
||||
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
|
||||
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
|
||||
async_ops.append(op)
|
||||
|
||||
weight_grad, weight_op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
|
||||
|
||||
input_op.wait()
|
||||
weight_op.wait()
|
||||
if ctx.use_bias:
|
||||
bias_grad = weight_grad[-1]
|
||||
weight_grad = weight_grad[:-1]
|
||||
for op in async_ops:
|
||||
if op is not None:
|
||||
op.wait()
|
||||
|
||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
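
# Single-process picture (illustrative sketch, not the library API) of the collectives used by
# linear_3d above: all_gather concatenates the per-rank input slices along input_dim before the
# local matmul, and reduce_scatter sums the per-rank partial outputs and hands each rank one
# slice along output_dim. All tensors below are made-up stand-ins for the per-rank shards.
import torch

_rank_inputs = [torch.randn(2, 5) for _ in range(4)]    # input_ slices held by 4 ranks
_gathered = torch.cat(_rank_inputs, dim=0)              # effect of all_gather(..., input_dim, ...)

_rank_partials = [torch.randn(8, 6) for _ in range(4)]  # partial outputs of the local matmuls
_summed = torch.stack(_rank_partials).sum(dim=0)        # the "reduce" half of reduce_scatter
_scattered = list(_summed.chunk(4, dim=0))              # the "scatter" half: one slice per rank
assert _gathered.shape == (8, 5) and _scattered[0].shape == (2, 6)
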
|
||||
|
||||
|
||||
class layer_norm_3d(torch.autograd.Function):
|
||||
class classifier_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, weight: Tensor, bias: Tensor,
|
||||
normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
|
||||
ctx.use_bias = bias is not None
|
||||
|
||||
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
||||
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
||||
weight = broadcast(weight, src_rank, input_parallel_mode)
|
||||
ctx.save_for_backward(input_, weight)
|
||||
|
||||
output = torch.matmul(input_, weight.transpose(0, 1))
|
||||
output = all_reduce(output, output_parallel_mode)
|
||||
|
||||
if bias is not None:
|
||||
output += bias
|
||||
|
||||
ctx.src_rank = src_rank
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
ctx.output_parallel_mode = output_parallel_mode
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
input_, weight = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
async_ops = list()
|
||||
|
||||
weight_grad = torch.matmul(
|
||||
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
|
||||
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
|
||||
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
|
||||
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
|
||||
async_ops.append(op)
|
||||
else:
|
||||
weight_grad = None
|
||||
|
||||
if ctx.use_bias:
|
||||
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
|
||||
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
|
||||
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
|
||||
async_ops.append(op)
|
||||
|
||||
input_grad = torch.matmul(output_grad, weight)
|
||||
|
||||
for op in async_ops:
|
||||
if op is not None:
|
||||
op.wait()
|
||||
|
||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
|
||||
|
||||
|
||||
class layernorm_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
# mean = torch.sum(input_, dim=-1)
|
||||
# dist.all_reduce(mean, group=gpc.get_group(output_parallel_mode))
|
||||
# mean /= normalized_shape
|
||||
# mu = input_ - mean
|
||||
# var = torch.sum(torch.pow(mu, 2), dim=-1)
|
||||
# dist.all_reduce(var, group=gpc.get_group(output_parallel_mode))
|
||||
# var /= normalized_shape
|
||||
# std_dev = torch.sqrt(var + eps)
|
||||
# ctx.save_for_backward(input_, mu, std_dev, weight)
|
||||
|
||||
# output = weight * mu / std_dev + bias
|
||||
|
||||
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True),
|
||||
output_parallel_mode) / normalized_shape
|
||||
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
|
||||
mu = input_ - mean
|
||||
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True),
|
||||
output_parallel_mode) / normalized_shape
|
||||
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
|
||||
sigma = torch.sqrt(var + eps)
|
||||
|
||||
# ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
||||
# src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
||||
# transforms = torch.stack([weight, bias]).contiguous()
|
||||
# dist.broadcast(transforms,
|
||||
# src=src_rank,
|
||||
# group=gpc.get_group(input_parallel_mode))
|
||||
# transforms = all_gather(transforms, -1, weight_parallel_mode)
|
||||
# weight, bias = transforms[0], transforms[1]
|
||||
|
||||
ctx.save_for_backward(mu, sigma, weight)
|
||||
|
||||
z = mu / sigma
|
||||
output = weight * z + bias
|
||||
|
||||
# ctx.src_rank = src_rank
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
@@ -181,7 +153,7 @@ class layer_norm_3d(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
mu, sigma, weight = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
|
||||
@@ -191,373 +163,63 @@ class layer_norm_3d(torch.autograd.Function):
|
||||
grads = all_reduce(grads, ctx.input_parallel_mode)
|
||||
bias_grad, weight_grad = grads[0], grads[1]
|
||||
|
||||
# grads = reduce_scatter(grads, -1, ctx.weight_parallel_mode)
|
||||
# dist.reduce(grads,
|
||||
# dst=ctx.src_rank,
|
||||
# group=gpc.get_group(ctx.input_parallel_mode))
|
||||
# if gpc.get_local_rank(
|
||||
# ctx.input_parallel_mode) == gpc.get_local_rank(
|
||||
# ctx.output_parallel_mode):
|
||||
# bias_grad, weight_grad = grads[0], grads[1]
|
||||
# else:
|
||||
# bias_grad, weight_grad = None, None
|
||||
|
||||
dz = output_grad * weight
|
||||
dvar = dz * mu * (-0.5) * sigma**(-3)
|
||||
dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
|
||||
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
|
||||
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
|
||||
|
||||
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
|
||||
input_grad = dz / sigma + dvar * 2 * mu / \
|
||||
ctx.normalized_shape + dmean / ctx.normalized_shape
|
||||
|
||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
||||
|
||||
|
||||
class Matmul_AB_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB`
|
||||
"""
|
||||
def split_batch_3d(input_: Tensor,
                   input_parallel_mode: ParallelMode,
                   weight_parallel_mode: ParallelMode,
                   dim: int = 0) -> Tensor:
    output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
                         dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
    output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
                         dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
    return output

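
# Single-process check (illustrative only) of the double split performed by split_batch_3d above:
# chunking along the batch dimension first by the weight group and then by the input group leaves
# each rank with batch_size / (q * q) samples. The ranks below are made-up example indices.
import torch

_batch = torch.arange(16).reshape(16, 1)
_q = 2                                                  # world size of each of the two groups
_step1 = _batch.chunk(_q, dim=0)[1].contiguous()        # slice seen by rank 1 of the weight group
_step2 = _step1.chunk(_q, dim=0)[0].contiguous()        # then by rank 0 of the input group
assert _step2.shape[0] == 16 // (_q * _q)
assert torch.equal(_step2, torch.arange(8, 12).reshape(4, 1))

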
class reduce_by_batch_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode,
|
||||
input_dim: int = 0,
|
||||
weight_dim: int = -1,
|
||||
output_dim: int = 0) -> Tensor:
|
||||
# A: [m/q^2, n, k/q]
|
||||
# B: [k/q, h/q^2]
|
||||
# C: [m/q^2, n, h/q]
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
assert A.shape[-1] == B.shape[0], \
|
||||
'Invalid shapes: A={}, B={}.'.format(A.shape, B.shape)
|
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
||||
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
|
||||
|
||||
C = torch.matmul(A_temp, B_temp)
|
||||
out = reduce_scatter(C, output_dim, output_parallel_mode)
|
||||
|
||||
ctx.depth = depth
|
||||
ctx.A_group_parallel_mode = input_parallel_mode
|
||||
ctx.B_group_parallel_mode = weight_parallel_mode
|
||||
ctx.C_group_parallel_mode = output_parallel_mode
|
||||
ctx.A_dim = input_dim
|
||||
ctx.B_dim = weight_dim
|
||||
ctx.C_dim = output_dim
|
||||
|
||||
return out
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
|
||||
output = all_reduce(input_, input_parallel_mode)
|
||||
output = all_reduce(output, weight_parallel_mode)
|
||||
return output.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_3D.apply(output_grad, B, ctx.depth,
|
||||
ctx.C_group_parallel_mode,
|
||||
ctx.B_group_parallel_mode,
|
||||
ctx.A_group_parallel_mode, ctx.C_dim,
|
||||
ctx.B_dim, ctx.A_dim)
|
||||
B_grad = Matmul_ATB_3D.apply(A, output_grad, ctx.depth,
|
||||
ctx.A_group_parallel_mode,
|
||||
ctx.C_group_parallel_mode,
|
||||
ctx.B_group_parallel_mode, ctx.A_dim,
|
||||
ctx.C_dim, ctx.B_dim)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Matmul_ABT_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = AB^T`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode,
|
||||
input_dim: int = 0,
|
||||
weight_dim: int = -1,
|
||||
output_dim: int = 0) -> Tensor:
|
||||
# A: [m/q^2, n, h/q]
|
||||
# B: [k/q, h/q^2]
|
||||
# C: [m/q^2, n, k/q]
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
||||
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
|
||||
|
||||
C = torch.matmul(A_temp, B_temp.transpose(0, 1))
|
||||
out = reduce_scatter(C, output_dim, output_parallel_mode)
|
||||
|
||||
ctx.depth = depth
|
||||
ctx.A_group_parallel_mode = input_parallel_mode
|
||||
ctx.B_group_parallel_mode = weight_parallel_mode
|
||||
ctx.C_group_parallel_mode = output_parallel_mode
|
||||
ctx.A_dim = input_dim
|
||||
ctx.B_dim = weight_dim
|
||||
ctx.C_dim = output_dim
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_AB_3D.apply(output_grad, B, ctx.depth,
|
||||
ctx.C_group_parallel_mode,
|
||||
ctx.B_group_parallel_mode,
|
||||
ctx.A_group_parallel_mode, ctx.C_dim,
|
||||
ctx.B_dim, ctx.A_dim)
|
||||
B_grad = Matmul_ATB_3D.apply(output_grad, A, ctx.depth,
|
||||
ctx.C_group_parallel_mode,
|
||||
ctx.A_group_parallel_mode,
|
||||
ctx.B_group_parallel_mode, ctx.C_dim,
|
||||
ctx.A_dim, ctx.B_dim)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Matmul_ATB_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = A^TB`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode,
|
||||
input_dim: int = 0,
|
||||
weight_dim: int = 0,
|
||||
output_dim: int = -1) -> Tensor:
|
||||
# A: [m/q^2, n, k/q]
|
||||
# B: [m/q^2, n, h/q]
|
||||
# C: [k/q, h/q^2]
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_temp = all_gather(A, input_dim, input_parallel_mode)
|
||||
A_temp = A_temp.reshape(-1, A.shape[-1])
|
||||
B_temp = all_gather(B, weight_dim, weight_parallel_mode)
|
||||
B_temp = B_temp.reshape(-1, B.shape[-1])
|
||||
|
||||
C = torch.matmul(A_temp.transpose(0, 1), B_temp)
|
||||
out = reduce_scatter(C, output_dim, output_parallel_mode)
|
||||
|
||||
ctx.depth = depth
|
||||
ctx.A_group_parallel_mode = input_parallel_mode
|
||||
ctx.B_group_parallel_mode = weight_parallel_mode
|
||||
ctx.C_group_parallel_mode = output_parallel_mode
|
||||
ctx.A_dim = input_dim
|
||||
ctx.B_dim = weight_dim
|
||||
ctx.C_dim = output_dim
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
A, B = ctx.saved_tensors
|
||||
with torch.no_grad():
|
||||
A_grad = Matmul_ABT_3D.apply(B, output_grad, ctx.depth,
|
||||
ctx.B_group_parallel_mode,
|
||||
ctx.C_group_parallel_mode,
|
||||
ctx.A_group_parallel_mode, ctx.B_dim,
|
||||
ctx.C_dim, ctx.A_dim)
|
||||
B_grad = Matmul_AB_3D.apply(A, output_grad, ctx.depth,
|
||||
ctx.A_group_parallel_mode,
|
||||
ctx.C_group_parallel_mode,
|
||||
ctx.B_group_parallel_mode, ctx.A_dim,
|
||||
ctx.C_dim, ctx.B_dim)
|
||||
return A_grad, B_grad, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class Add_3D(torch.autograd.Function):
|
||||
"""Matrix add bias: :math:`C = A + b`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
# input: [m/q^2, n, h/q]
|
||||
# bias: [h/q^2]
|
||||
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
||||
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
||||
bias_temp = bias.clone()
|
||||
dist.broadcast(bias_temp,
|
||||
src=src_rank,
|
||||
group=gpc.get_group(input_parallel_mode))
|
||||
# [h/q]
|
||||
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
|
||||
|
||||
out = input_ + bias_temp
|
||||
|
||||
ctx.depth = depth
|
||||
ctx.src_rank = src_rank
|
||||
ctx.A_group_parallel_mode = input_parallel_mode
|
||||
ctx.B_group_parallel_mode = weight_parallel_mode
|
||||
ctx.C_group_parallel_mode = output_parallel_mode
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# output_grad: [m/q^2, n, h/q]
|
||||
with torch.no_grad():
|
||||
# [h/q]
|
||||
grad = torch.sum(output_grad,
|
||||
dim=tuple(range(len(output_grad.shape))[:-1]))
|
||||
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
|
||||
dist.reduce(bias_grad,
|
||||
dst=ctx.src_rank,
|
||||
group=gpc.get_group(ctx.A_group_parallel_mode))
|
||||
if gpc.get_local_rank(
|
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
||||
ctx.C_group_parallel_mode):
|
||||
bias_grad = None
|
||||
return output_grad, bias_grad, None, None, None, None
|
||||
|
||||
|
||||
class Mul_3D(torch.autograd.Function):
|
||||
"""Matrix multiplication for :math:`C = A * b`
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, bias: Tensor, depth: int,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
# input: [m/q^2, n, h/q]
|
||||
# bias: [h/q^2]
|
||||
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
||||
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
||||
# [h/q^2]
|
||||
bias_temp = bias.clone()
|
||||
dist.broadcast(bias_temp,
|
||||
src=src_rank,
|
||||
group=gpc.get_group(input_parallel_mode))
|
||||
# [h/q]
|
||||
bias_temp = all_gather(bias_temp, -1, weight_parallel_mode)
|
||||
|
||||
# empty_cache()
|
||||
ctx.save_for_backward(input_, bias_temp)
|
||||
|
||||
out = torch.mul(input_, bias_temp)
|
||||
|
||||
ctx.depth = depth
|
||||
ctx.src_rank = src_rank
|
||||
ctx.A_group_parallel_mode = input_parallel_mode
|
||||
ctx.B_group_parallel_mode = weight_parallel_mode
|
||||
ctx.C_group_parallel_mode = output_parallel_mode
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# output_grad: [m/q^2, n, h/q]
|
||||
with torch.no_grad():
|
||||
input_, bias = ctx.saved_tensors
|
||||
# [m/q^2, n, h/q]
|
||||
input_grad = torch.mul(output_grad, bias)
|
||||
# [h/q]
|
||||
grad = torch.mul(output_grad, input_)
|
||||
grad = torch.sum(grad,
|
||||
dim=tuple(range(len(output_grad.shape))[:-1]))
|
||||
bias_grad = reduce_scatter(grad, -1, ctx.B_group_parallel_mode)
|
||||
dist.reduce(bias_grad,
|
||||
dst=ctx.src_rank,
|
||||
group=gpc.get_group(ctx.A_group_parallel_mode))
|
||||
if gpc.get_local_rank(
|
||||
ctx.A_group_parallel_mode) != gpc.get_local_rank(
|
||||
ctx.C_group_parallel_mode):
|
||||
bias_grad = None
|
||||
return input_grad, bias_grad, None, None, None, None
|
||||
|
||||
|
||||
class Sum_3D(torch.autograd.Function):
|
||||
"""Compute the sum of input tensors
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any,
|
||||
input_: Tensor,
|
||||
dim: int,
|
||||
depth: int,
|
||||
parallel_mode: ParallelMode,
|
||||
keepdim: bool = False) -> Tensor:
|
||||
# input: [m/q^2, n, h/q]
|
||||
out = torch.sum(input_, dim=dim, keepdim=keepdim)
|
||||
dist.all_reduce(out, group=gpc.get_group(parallel_mode))
|
||||
|
||||
ctx.input_shape = input_.shape
|
||||
ctx.depth = depth
|
||||
ctx.group = parallel_mode
|
||||
ctx.dim = dim
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
with torch.no_grad():
|
||||
output_grad = output_grad.contiguous()
|
||||
dist.all_reduce(output_grad, group=gpc.get_group(ctx.group))
|
||||
if len(output_grad.shape) < len(ctx.input_shape):
|
||||
output_grad = torch.unsqueeze(output_grad, ctx.dim)
|
||||
dims = [1 for _ in range(len(output_grad.shape))]
|
||||
dims[ctx.dim] = ctx.input_shape[ctx.dim]
|
||||
input_grad = output_grad.repeat(tuple(dims))
|
||||
return input_grad, None, None, None, None, None
|
||||
|
||||
|
||||
class Reduce_3D(torch.autograd.Function):
|
||||
"""Reduce input tensors
|
||||
"""
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, input_: Tensor, depth: int,
|
||||
parallel_mode: ParallelMode) -> Tensor:
|
||||
dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
|
||||
return input_.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
return output_grad, None, None
|
||||
|
||||
|
||||
# class Slice_3D(torch.autograd.Function):
|
||||
# """Slice input tensor
|
||||
# """
|
||||
# @staticmethod
|
||||
# @custom_fwd(cast_inputs=torch.float16)
|
||||
# def forward(ctx: Any, input_: Tensor, dim: int, depth: int,
|
||||
# parallel_mode: ParallelMode) -> Tensor:
|
||||
# rank = gpc.get_local_rank(parallel_mode)
|
||||
# out = torch.chunk(input_, depth, dim=dim)[rank].contiguous()
|
||||
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
|
||||
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
|
||||
output = broadcast(input_, src_rank, input_parallel_mode)
|
||||
ctx.src_rank = src_rank
|
||||
ctx.input_parallel_mode = input_parallel_mode
|
||||
ctx.weight_parallel_mode = weight_parallel_mode
|
||||
ctx.output_parallel_mode = output_parallel_mode
|
||||
return output
|
||||
|
||||
# ctx.depth = depth
|
||||
# ctx.parallel_mode = parallel_mode
|
||||
# ctx.dim = dim
|
||||
# ctx.input_shape = input_.shape
|
||||
|
||||
# return out
|
||||
|
||||
# @staticmethod
|
||||
# @custom_bwd
|
||||
# def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
# with torch.no_grad():
|
||||
# input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
|
||||
# input_grad.reshape(ctx.input_shape)
|
||||
# return input_grad, None, None, None
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
|
||||
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
|
||||
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
|
||||
else:
|
||||
input_grad = None
|
||||
return input_grad, None, None, None
|
||||
|
@@ -1,413 +0,0 @@
|
||||
import math
|
||||
import os
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.nn.init import init_bias_, init_weight_
|
||||
from colossalai.utils import checkpoint, get_current_device
|
||||
from torch import Tensor, dtype, nn
|
||||
|
||||
from .._common_utils import ACT2FN, divide, set_tensor_parallel_attribute_by_size, to_2tuple
|
||||
from ._utils import get_depth_from_env, get_parallel_mode_from_env, get_last_group
|
||||
from .layers import Linear3D
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTPatchEmbedding3D(nn.Module):
|
||||
""" 3D Image to Patch Embedding
|
||||
|
||||
:param img_size: image size
|
||||
:type img_size: int
|
||||
:param patch_size: patch size
|
||||
:type patch_size: int
|
||||
:param in_chans: number of channels of input image
|
||||
:type in_chans: int
|
||||
:param embed_size: dimension of embedding
|
||||
:type embed_size: int
|
||||
:param drop_prob: dropout probability
|
||||
:type drop_prob: float
|
||||
:param flatten: whether to flatten output tensor, defaults to True
|
||||
:type flatten: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
drop_prob: float,
|
||||
flatten: bool = True,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.in_chans = in_chans
|
||||
self.embed_size = embed_size
|
||||
self.embed_size_per_partition = divide(self.embed_size, self.depth)
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.init_weight = 'torch'
|
||||
self.init_bias = 'torch'
|
||||
if init_method == 'jax':
|
||||
self.init_weight = 'jax_embed'
|
||||
self.init_bias = 'zero'
|
||||
|
||||
self.proj = nn.Conv2d(self.in_chans,
|
||||
self.embed_size_per_partition,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size)
|
||||
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros(1, 1, self.embed_size_per_partition))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, self.num_patches + 1,
|
||||
self.embed_size_per_partition))
|
||||
self.pos_drop = nn.Dropout(drop_prob)
|
||||
|
||||
self.reset_parameters(self.init_weight, self.init_bias)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_size(self.proj.weight, self.in_chans * self.embed_size * self.num_patches)
|
||||
set_tensor_parallel_attribute_by_size(self.proj.bias, self.embed_size)
|
||||
set_tensor_parallel_attribute_by_size(self.cls_token, 1 * 1 * self.embed_size)
|
||||
set_tensor_parallel_attribute_by_size(self.pos_embed, 1 * (self.num_patches + 1) * self.embed_size)
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.proj.weight)
|
||||
# std = math.sqrt(1.0 / fan_in)
|
||||
# nn.init.trunc_normal_(self.proj.weight, std=std / .87962566103423978)
|
||||
# nn.init.zeros_(self.proj.bias)
|
||||
if init_weight != 'torch':
|
||||
init_weight_(self.proj.weight, fan_in, init_method=init_weight)
|
||||
init_bias_(self.pos_embed, fan_in, init_method=init_weight)
|
||||
if init_bias != 'torch':
|
||||
init_bias_(self.proj.bias, fan_in, init_method=init_bias)
|
||||
|
||||
self.to(get_current_device())
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
dist.broadcast(self.proj.weight,
|
||||
src=weight_src_rank,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
dist.broadcast(self.proj.bias,
|
||||
src=weight_src_rank,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
|
||||
dist.broadcast(self.proj.weight,
|
||||
src=input_src_rank,
|
||||
group=gpc.get_group(self.input_parallel_mode))
|
||||
dist.broadcast(self.proj.bias,
|
||||
src=input_src_rank,
|
||||
group=gpc.get_group(self.input_parallel_mode))
|
||||
|
||||
self.proj.weight.register_hook(self._sync_grad_hook)
|
||||
self.proj.bias.register_hook(self._sync_grad_hook)
|
||||
self.cls_token.register_hook(self._sync_grad_hook)
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> None:
|
||||
dist.all_reduce(grad, group=gpc.get_group(self.input_parallel_mode))
|
||||
dist.all_reduce(grad, group=gpc.get_group(self.weight_parallel_mode))
|
||||
return grad
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# split a partition from inputs
|
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
|
||||
self.weight_parallel_mode)].contiguous()
|
||||
x = torch.chunk(x, self.depth, dim=0)[gpc.get_local_rank(
|
||||
self.input_parallel_mode)].contiguous()
|
||||
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
# add cls token & pos embedding
|
||||
# [b/q^2,s,h/q] --> [b/q^2, 1+s, h/q]
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTSelfAttention3D(nn.Module):
|
||||
"""Self-attention layer for 3D parallel Vision Transformer
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_probs_dropout_prob: dropout probability for attention layers
|
||||
:type attention_probs_dropout_prob: bool
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers
|
||||
:type hidden_dropout_prob: bool
|
||||
:param depth: the 3D parallelism depth
|
||||
:type depth: int
|
||||
:param input_parallel_mode: parallel mode of input tensor
|
||||
:type input_parallel_mode: ParallelMode
|
||||
:param weight_parallel_mode: parallel mode of weight
|
||||
:type weight_parallel_mode: ParallelMode
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: dtype, optional
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_probs_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
# self.weight_parallel_mode)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = divide(num_attention_heads, self.depth)
|
||||
self.attention_head_size = divide(hidden_size, num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.checkpoint = checkpoint
|
||||
self.init_weight = 'torch'
|
||||
self.init_bias = 'torch'
|
||||
if init_method == 'jax':
|
||||
self.init_weight = 'jax'
|
||||
self.init_bias = 'zero'
|
||||
|
||||
self.query_key_value = Linear3D(self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.attention_dropout = nn.Dropout(attention_probs_dropout_prob)
|
||||
self.dense = Linear3D(self.hidden_size,
|
||||
self.hidden_size,
|
||||
# self.output_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
# return self.input_parallel_mode, self.weight_parallel_mode
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads, 3 * self.attention_head_size)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(query_key_value,
|
||||
3,
|
||||
dim=-1)
|
||||
|
||||
attention_scores = torch.matmul(query_layer,
|
||||
key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(
|
||||
self.attention_head_size)
|
||||
attention_probs = self.softmax(attention_scores)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.transpose(1, 2)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||
self.all_head_size, )
|
||||
context_layer = context_layer.reshape(new_context_layer_shape)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
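For intuition on the head partitioning in `__init__` above: the attention heads are split across the cube depth, while the per-head size stays global, so the local projection width equals `hidden_size / depth`. A small worked example with hypothetical sizes (not taken from any config in this diff):

```
# Hypothetical sizes, for illustration only.
hidden_size = 768
num_attention_heads = 12
depth = 2  # cube root of the 3D tensor-parallel size

heads_per_rank = num_attention_heads // depth       # 6 heads handled on this rank
head_size = hidden_size // num_attention_heads      # 64, identical on every rank
all_head_size = heads_per_rank * head_size          # 384 == hidden_size // depth
```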
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTMLP3D(nn.Module):
|
||||
"""[summary]
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param mlp_ratio: hidden size of MLP divided by embedding dim
|
||||
:type mlp_ratio: int
|
||||
:param hidden_dropout_prob: dropout probability for hidden layers
|
||||
:type hidden_dropout_prob: float
|
||||
:param hidden_act: activation function for hidden layers
|
||||
:type hidden_act: str
|
||||
:param depth: the 3D parallelism depth
|
||||
:type depth: int
|
||||
:param input_parallel_mode: parallel mode of input tensor
|
||||
:type input_parallel_mode: ParallelMode
|
||||
:param weight_parallel_mode: parallel mode of weight
|
||||
:type weight_parallel_mode: ParallelMode
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: dtype, optional
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
mlp_ratio: int,
|
||||
hidden_dropout_prob: float,
|
||||
hidden_act: str = 'gelu',
|
||||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
checkpoint: bool = False,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
# self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
# self.weight_parallel_mode)
|
||||
self.hidden_size = hidden_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.checkpoint = checkpoint
|
||||
self.init_weight = init_method
|
||||
self.init_bias = init_method
|
||||
|
||||
self.dense_1 = Linear3D(self.hidden_size,
|
||||
self.mlp_ratio * self.hidden_size,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.activation_func = ACT2FN[hidden_act]
|
||||
self.dense_2 = Linear3D(self.mlp_ratio * self.hidden_size,
|
||||
self.hidden_size,
|
||||
# self.output_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
# def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
|
||||
# return self.input_parallel_mode, self.weight_parallel_mode
|
||||
|
||||
def _forward(self, hidden_states: Tensor) -> Tensor:
|
||||
intermediate_output = self.dense_1(hidden_states)
|
||||
intermediate_output = self.activation_func(intermediate_output)
|
||||
output = self.dense_2(intermediate_output)
|
||||
with seed(ParallelMode.TENSOR):
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
|
||||
return checkpoint(self._forward, hidden_states)
|
||||
|
||||
def forward(self, hidden_states: Tensor) -> Tensor:
|
||||
if self.checkpoint:
|
||||
return self._checkpoint_forward(hidden_states)
|
||||
else:
|
||||
return self._forward(hidden_states)
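The `checkpoint` flag used by both blocks above trades compute for memory: the block's activations are recomputed during backward instead of being stored. A minimal sketch of the same toggle built on `torch.utils.checkpoint` (an illustration of the idea, not the project's own `checkpoint` utility):

```
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    def __init__(self, block: nn.Module, use_checkpoint: bool = False):
        super().__init__()
        self.block = block
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        if self.use_checkpoint:
            # Re-run self.block during backward to rebuild activations on the fly.
            return checkpoint(self.block, x)
        return self.block(x)
```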
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class ViTHead3D(nn.Module):
|
||||
"""Output layer for 3D parallel Vision Transformer
|
||||
|
||||
:param in_features: size of input tensor
|
||||
:type in_features: int
|
||||
:param num_classes: number of classes
|
||||
:type num_classes: int
|
||||
:param depth: the 3D parallelism depth
|
||||
:type depth: int
|
||||
:param input_parallel_mode: parallel mode of input tensor
|
||||
:type input_parallel_mode: ParallelMode
|
||||
:param weight_parallel_mode: parallel mode of weight
|
||||
:type weight_parallel_mode: ParallelMode
|
||||
:param dtype: dtype of parameters, defaults to None
|
||||
:type dtype: dtype, optional
|
||||
:param bias: whether to add bias, defaults to True
|
||||
:type bias: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
dtype: dtype = None,
|
||||
bias: bool = True,
|
||||
init_method: str = 'torch'):
|
||||
super().__init__()
|
||||
# self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
# self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
# self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
# self.weight_parallel_mode)
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
# out_features = math.ceil(self.num_classes /
|
||||
# (self.depth**2)) * (self.depth**2)
|
||||
# self.num_classes_per_partition = divide(self.num_classes, self.depth)
|
||||
self.init_weight = 'torch'
|
||||
self.init_bias = 'torch'
|
||||
if init_method == 'jax':
|
||||
self.init_weight = 'zero'
|
||||
self.init_bias = 'zero'
|
||||
|
||||
self.linear = Linear3D(self.in_features,
|
||||
self.num_classes,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
init_weight=self.init_weight,
|
||||
init_bias=self.init_bias)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# [b/q^2, s, h/q] --> [b/q^2, h/q]
|
||||
x = x[:, 0]
|
||||
# [b/q^2, h/q] --> [b/q^2, c/q]
|
||||
x = self.linear(x)
|
||||
# return x[:, :self.num_classes_per_partition]
|
||||
return x
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, num_classes={}'.format(self.in_features,
|
||||
self.num_classes)
|
@ -1,191 +1,311 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Tuple
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
import torch.nn.functional as F
|
||||
from colossalai.communication import all_reduce, broadcast
|
||||
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.init import init_bias_, init_weight_
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.nn.layer.base_layer import ParallelLayer
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
from torch.nn import init as init
|
||||
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_size
|
||||
from ._operation import (Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D, layer_norm_3d,
|
||||
linear_3d)
|
||||
from ._utils import (get_depth_from_env, get_last_group,
|
||||
get_parallel_mode_from_env, swap_in_out_group)
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ._operation import *
|
||||
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class LayerNorm3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape: int,
|
||||
# input_parallel_mode: ParallelMode,
|
||||
# weight_parallel_mode: ParallelMode,
|
||||
eps: float = 1e-12,
|
||||
dtype: dtype = None,
|
||||
):
|
||||
class LayerNorm3D(ParallelLayer):
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None):
|
||||
super().__init__()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.depth = get_depth_from_env()
|
||||
self.normalized_shape = normalized_shape
|
||||
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.ones(self.normalized_shape_per_partition,
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.normalized_shape_per_partition,
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition, device=get_current_device(),
|
||||
dtype=dtype))
|
||||
self.variance_epsilon = eps
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_size(self.weight, self.normalized_shape)
|
||||
set_tensor_parallel_attribute_by_size(self.bias, self.normalized_shape)
|
||||
def _set_tensor_parallel_attributes(self) -> None:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
|
||||
|
||||
def reset_parameters(self):
|
||||
init.zeros_(self.bias)
|
||||
init.ones_(self.weight)
|
||||
def reset_parameters(self) -> None:
|
||||
init.zeros_()(self.bias)
|
||||
init.ones_()(self.weight)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
|
||||
# # input: [m/q^2, n, h/q]
|
||||
# # [m/q^2, n, 1]
|
||||
# mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
|
||||
# True) / self.normalized_shape
|
||||
# # [m/q^2, n, 1]
|
||||
# var = (input_ - mean).pow(2)
|
||||
# var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
|
||||
# True) / self.normalized_shape
|
||||
|
||||
# output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
|
||||
# output = Mul_3D.apply(output, self.weight, self.depth,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.output_parallel_mode)
|
||||
# output = Add_3D.apply(output, self.bias, self.depth,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.output_parallel_mode)
|
||||
# return output
|
||||
return layer_norm_3d.apply(input_, self.weight, self.bias,
|
||||
self.normalized_shape,
|
||||
self.variance_epsilon,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
|
||||
def extra_repr(self):
|
||||
return '{}, eps={}'.format(self.normalized_shape,
|
||||
self.variance_epsilon)
|
||||
return layernorm_3d.apply(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
|
||||
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Linear3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
# input_parallel_mode: ParallelMode,
|
||||
# weight_parallel_mode: ParallelMode,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
init_weight: str = 'torch',
|
||||
init_bias: str = 'torch'):
|
||||
class Linear3D(ParallelLayer):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
# self.with_bias = bias
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.depth = get_depth_from_env()
|
||||
self.in_features_per_partition = divide(in_features, self.depth)
|
||||
self.out_features_per_partition = divide(out_features, self.depth)
|
||||
|
||||
# [k/q, h/q]
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.in_features_per_partition,
|
||||
self.out_features_per_partition,
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
# [h/q]
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.out_features_per_partition,
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
|
||||
dtype=dtype))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.bias = None
|
||||
|
||||
self.reset_parameters(init_weight, init_bias)
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
swap_in_out_group()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_size(self.weight, self.in_features * self.out_features)
|
||||
def _set_tensor_parallel_attributes(self) -> None:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
|
||||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_size(self.bias, self.out_features)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
|
||||
|
||||
def reset_parameters(self, init_weight, init_bias) -> None:
|
||||
# setting
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
|
||||
|
||||
# init weight
|
||||
init_weight_(self.weight, fan_in, fan_out, init_method=init_weight)
|
||||
dist.broadcast(self.weight,
|
||||
src=weight_src_rank,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
# init bias
|
||||
if self.bias is not None:
|
||||
init_bias_(self.bias, fan_in, init_method=init_bias)
|
||||
dist.broadcast(self.bias,
|
||||
src=weight_src_rank,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
dist.broadcast(self.bias,
|
||||
src=output_src_rank,
|
||||
group=gpc.get_group(self.output_parallel_mode))
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# # input: [m/q^2, n, k/q]
|
||||
# # output: [m/q^2, n, h/q]
|
||||
# output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
|
||||
# self.input_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.output_parallel_mode)
|
||||
|
||||
# if self.bias is not None:
|
||||
# output = Add_3D.apply(output, self.bias, self.depth,
|
||||
# self.output_parallel_mode,
|
||||
# self.weight_parallel_mode,
|
||||
# self.input_parallel_mode)
|
||||
# return output
|
||||
return linear_3d.apply(input_, self.weight, self.bias,
|
||||
self.input_parallel_mode,
|
||||
self.weight_parallel_mode,
|
||||
return linear_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
|
||||
def extra_repr(self):
|
||||
return 'in_features={}, out_features={}, bias={}'.format(
|
||||
self.in_features, self.out_features, self.with_bias)
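`reset_parameters` above follows a common pattern for partitioned layers: initialize under a shared tensor-parallel seed, then broadcast from the first rank of a group so all replicas start from identical values. A plain `torch.distributed` sketch of that pattern (the group handle and source rank are placeholders):

```
import math

import torch
import torch.distributed as dist


def init_and_sync(param: torch.Tensor, src_rank: int, group=None) -> None:
    # Initialize locally, then broadcast from `src_rank` so that every rank in
    # `group` ends up holding exactly the same values.
    torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
    dist.broadcast(param, src=src_rank, group=group)
```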
|
||||
|
||||
@LAYERS.register_module
|
||||
class Classifier3D(ParallelLayer):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.depth = get_depth_from_env()
|
||||
self.in_features_per_partition = divide(in_features, self.depth)
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self) -> None:
|
||||
if self.has_weight:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
|
||||
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
|
||||
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return classifier_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
|
||||
self.output_parallel_mode)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class PatchEmbedding3D(ParallelLayer):
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_()):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
grid_size = to_2tuple(img_size // patch_size)
|
||||
num_patches = grid_size[0] * grid_size[1]
|
||||
self.embed_size = embed_size
|
||||
embed_size_per_partition = divide(embed_size, self.depth)
|
||||
self.flatten = flatten
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
dtype=dtype))
|
||||
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self) -> None:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
|
||||
|
||||
def _sync_grad_hook(self, grad) -> Tensor:
|
||||
grad = all_reduce(grad, self.input_parallel_mode)
|
||||
grad = all_reduce(grad, self.weight_parallel_mode)
|
||||
return grad
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
fan_out = self.embed_size
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
|
||||
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
||||
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
|
||||
|
||||
self.bias.register_hook(self._sync_grad_hook)
|
||||
self.cls_token.register_hook(self._sync_grad_hook)
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
|
||||
if self.flatten:
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||
output = torch.cat((cls_token, output), dim=1)
|
||||
output = output + self.pos_embed
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Embedding3D(ParallelLayer):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = divide(embedding_dim, self.depth)
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self) -> None:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
self._fill_padding_idx_with_zero()
|
||||
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
|
||||
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
return output
|
||||
|
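Embedding3D above stores only `embedding_dim / depth` columns per rank, so a local lookup returns each token's slice of the full embedding vector. A single-process sketch with hypothetical sizes:

```
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
num_embeddings, embedding_dim, depth = 1000, 512, 2
local_dim = embedding_dim // depth                 # 256 columns stored on this rank
local_weight = torch.randn(num_embeddings, local_dim)

tokens = torch.tensor([[3, 7, 42]])
local_out = F.embedding(tokens, local_weight)      # [1, 3, 256]: this rank's slice of the embeddings
```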
3
colossalai/nn/layer/vanilla/__init__.py
Normal file
@ -0,0 +1,3 @@
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding

__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath']
134
colossalai/nn/layer/vanilla/layers.py
Normal file
@ -0,0 +1,134 @@
import math
from typing import Callable

import torch
import torch.nn.functional as F
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch import nn as nn

from .._common_utils import to_2tuple


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

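As a quick usage sketch (sizes are illustrative), `DropPath` is typically applied to the residual branch of a transformer block and becomes the identity at inference time:

```
import torch

drop = DropPath(drop_prob=0.1)
drop.train()

x = torch.randn(4, 197, 768)                 # hypothetical [batch, tokens, dim]
branch = torch.randn_like(x)
y = x + drop(branch)                         # per sample, the branch is randomly zeroed and rescaled

drop.eval()
assert torch.equal(drop(branch), branch)     # no-op in eval mode
```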
@LAYERS.register_module
|
||||
class VanillaPatchEmbedding(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_()):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
|
||||
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_size))
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
||||
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
B, C, H, W = input_.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
||||
if self.flatten:
|
||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||
output = torch.cat((cls_token, output), dim=1)
|
||||
output = output + self.pos_embed
|
||||
return output
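For intuition on the patch embedding above, here is the same conv-and-flatten sequence run on hypothetical sizes, showing how an image becomes a token sequence:

```
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
B, in_chans, img_size, patch_size, embed_size = 2, 3, 224, 16, 768
x = torch.randn(B, in_chans, img_size, img_size)
w = torch.randn(embed_size, in_chans, patch_size, patch_size)

out = F.conv2d(x, w, stride=patch_size)      # [2, 768, 14, 14]: one vector per 16x16 patch
out = out.flatten(2).transpose(1, 2)         # [2, 196, 768]
cls = torch.zeros(B, 1, embed_size)
out = torch.cat((cls, out), dim=1)           # [2, 197, 768] after prepending the cls token
```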
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VanillaClassifier(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: nn.Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer):
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return F.linear(input_, self.weight, self.bias)
|
@ -1,5 +1,26 @@
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D
from torch import nn
from torch.nn.modules.loss import *
from torch.nn.modules.loss import _Loss

__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D']
from .loss_2d import CrossEntropyLoss2D
from .loss_2p5d import CrossEntropyLoss2p5D
from .loss_3d import CrossEntropyLoss3D

_parallel_cross_entropy = {
    '2d': CrossEntropyLoss2D,
    '2.5d': CrossEntropyLoss2p5D,
    '3d': CrossEntropyLoss3D
}


class CrossEntropyLoss(_Loss):
    def __init__(self, reduction: bool = True, tensor_parallel: str = None, *args, **kwargs):
        super().__init__()
        if tensor_parallel in [None, '1d']:
            reduction = 'mean' if reduction else 'none'
            self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
        else:
            self.loss = _parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)

    def forward(self, *args):
        return self.loss(*args)
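A short usage sketch of the `CrossEntropyLoss` wrapper above (mode string and shapes are illustrative): with no tensor parallelism it behaves like `nn.CrossEntropyLoss`; with a mode such as `'2d'` it dispatches to the corresponding parallel loss, which expects inputs already partitioned by the matching layers.

```
import torch

criterion = CrossEntropyLoss(tensor_parallel=None)   # data-parallel / 1D case

logits = torch.randn(8, 10)                           # hypothetical [batch, classes]
targets = torch.randint(0, 10, (8, ))
loss = criterion(logits, targets)
```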
@ -1,131 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
from colossalai.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
|
||||
### Modified based on megatron.mpu.cross_entropy ###
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, logits, targets):
|
||||
# logits: [b/q, h/q]
|
||||
# labels: [b/q]
|
||||
|
||||
logits_max = torch.max(logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(
|
||||
logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
|
||||
# Subtract the maximum value.
|
||||
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
logits = logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
vocab_size = logits.size(-1)
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
vocab_start = rank * (vocab_size)
|
||||
vocab_end = (rank + 1) * (vocab_size) - 1
|
||||
|
||||
target_mask = (targets < vocab_start) | (targets > vocab_end)
|
||||
|
||||
masked_target = targets.clone() - vocab_start
|
||||
masked_target[target_mask] = 0
|
||||
arange_1d = torch.arange(
|
||||
start=0, end=logits.size()[0],
|
||||
)
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits[target_mask] = 0.
|
||||
dist.all_reduce(predicted_logits, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_ROW))
|
||||
|
||||
exp_logits = torch.exp(logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=1)
|
||||
dist.all_reduce(sum_exp_logits, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_ROW))
|
||||
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad):
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
grad_input = softmax
|
||||
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
|
||||
device=get_current_device())
|
||||
grad_2d[arange_1d,
|
||||
masked_target] -= (1.0 - target_mask.view(-1).float())
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
||||
return grad_input, None
|
||||
|
||||
|
||||
class _ReduceByColumn(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_COL))
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(
|
||||
ParallelMode.PARALLEL_2D_COL))
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class CrossEntropyLoss2D(_Loss):
|
||||
"""Cross entropy loss for 2D parallelism
|
||||
|
||||
:param reduction: whether to average the loss, defaults to True
|
||||
:type reduction: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=True):
|
||||
super().__init__()
|
||||
assert_summa_initialization()
|
||||
self.summa_dim = get_summa_dim_from_env()
|
||||
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
self.reduction_mean = reduction
|
||||
|
||||
def forward(self, logits, targets):
|
||||
targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank]
|
||||
loss = _ParallelCrossEntropyLossFunction_2D.apply(
|
||||
logits, targets,
|
||||
)
|
||||
if self.reduction_mean:
|
||||
loss = _ReduceByColumn.apply(loss) / self.summa_dim
|
||||
dist_loss = loss.mean()
|
||||
|
||||
return dist_loss
|
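The deleted 2D loss above implements the usual vocab-parallel cross entropy: subtract the global max for numerical stability, pick the local piece of each target logit, and all-reduce both the picked logits and the sum of exponentials over the row group. A single-process sketch of the same arithmetic (no process groups; the concatenation stands in for the all-reduces):

```
import torch


def sliced_cross_entropy(logit_slices, targets):
    # logit_slices: list of [batch, vocab_slice] tensors, one per hypothetical rank.
    logits = torch.cat(logit_slices, dim=-1)                     # what the ranks jointly hold
    logits = logits - logits.max(dim=-1).values.unsqueeze(-1)    # all-reduce MAX in the real code

    picked = logits[torch.arange(logits.size(0)), targets]       # all-reduce SUM of masked picks
    sum_exp = torch.exp(logits).sum(dim=-1)                      # all-reduce SUM of partial sums
    return torch.log(sum_exp) - picked                           # loss = logsumexp - target logit


slices = [torch.randn(4, 5), torch.randn(4, 5)]                  # two hypothetical vocab slices
targets = torch.randint(0, 10, (4, ))
loss = sliced_cross_entropy(slices, targets).mean()
```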
@ -1,124 +0,0 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization, \
|
||||
get_tesseract_dim_dep_from_env
|
||||
from colossalai.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class _ParallelCrossEntropyLossFunction_2p5D(torch.autograd.Function):
|
||||
### Modified based on megatron.mpu.cross_entropy ###
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, logits, targets):
|
||||
# logits: [b/dq, h/q]
|
||||
# loss: [b/dq]
|
||||
# targets: [b/dq, h/q]
|
||||
logits_max = torch.max(logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(
|
||||
logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
# Subtract the maximum value.
|
||||
logits = logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
vocab_size = logits.size(-1)
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
vocab_start = rank * (vocab_size)
|
||||
vocab_end = (rank + 1) * (vocab_size) - 1
|
||||
|
||||
target_mask = (targets < vocab_start) | (targets > vocab_end)
|
||||
|
||||
masked_target = targets.clone() - vocab_start
|
||||
masked_target[target_mask] = 0
|
||||
arange_1d = torch.arange(
|
||||
start=0, end=logits.size()[0],
|
||||
)
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits[target_mask] = 0.
|
||||
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
|
||||
exp_logits = torch.exp(logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=1)
|
||||
dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
|
||||
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grad):
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
grad_input = softmax
|
||||
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
|
||||
device=get_current_device())
|
||||
grad_2d[arange_1d,
|
||||
masked_target] -= (1.0 - target_mask.view(-1).float())
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
||||
return grad_input, None
|
||||
|
||||
|
||||
class _ReduceByColDep(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_XZ))
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class CrossEntropyLoss2p5D(_Loss):
|
||||
"""Cross entropy loss for 2.5D parallelism
|
||||
|
||||
:param reduction: whether to average the loss, defaults to True
|
||||
:type reduction: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self, reduction=True):
|
||||
super().__init__()
|
||||
assert_tesseract_initialization()
|
||||
self.xz_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.reduction_mean = reduction
|
||||
|
||||
def forward(self, logits, targets):
|
||||
targets = targets.chunk(self.tesseract_dim *
|
||||
self.tesseract_dep, dim=0)[self.xz_rank]
|
||||
loss = _ParallelCrossEntropyLossFunction_2p5D.apply(
|
||||
logits, targets,
|
||||
)
|
||||
if self.reduction_mean:
|
||||
loss = _ReduceByColDep.apply(
|
||||
loss) / self.tesseract_dim / self.tesseract_dep
|
||||
dist_loss = loss.mean()
|
||||
|
||||
return dist_loss
|
@ -1,183 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_3d._operation import Reduce_3D
|
||||
from colossalai.nn.layer.parallel_3d._utils import (get_depth_from_env,
|
||||
get_last_group,
|
||||
get_parallel_mode_from_env)
|
||||
from colossalai.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
class _ParallelCrossEntropyLossFunction_3D(torch.autograd.Function):
|
||||
"""
|
||||
Adapted from megatron.mpu.cross_entropy
|
||||
loss[i] = -logits[i][targets] + log(sum(exp(logits[i])))
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, logits, targets, depth, output_parallel_mode):
|
||||
# logits: [b/q^2, c/q]
|
||||
# labels: [b/q^2]
|
||||
# loss: [b/q^2]
|
||||
logits_max = torch.max(logits, dim=-1)[0]
|
||||
dist.all_reduce(logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(output_parallel_mode))
|
||||
# Subtract the maximum value.
|
||||
logits = logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
vocab_size_per_partition = logits.size()[-1]
|
||||
rank = gpc.get_local_rank(output_parallel_mode)
|
||||
vocab_start = rank * vocab_size_per_partition
|
||||
vocab_end = (rank + 1) * vocab_size_per_partition - 1
|
||||
|
||||
# loss[i] = 0 if targets[i] < vocab_start or targets[i] > vocab_end
|
||||
target_mask = (targets < vocab_start) | (targets > vocab_end)
|
||||
masked_target = targets.clone() - vocab_start
|
||||
masked_target[target_mask] = 0
|
||||
arange_1d = torch.arange(start=0,
|
||||
end=logits.size()[0],
|
||||
device=get_current_device())
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits = predicted_logits.clone().contiguous().view_as(
|
||||
targets)
|
||||
predicted_logits[target_mask] = 0.
|
||||
dist.all_reduce(predicted_logits,
|
||||
group=gpc.get_group(output_parallel_mode))
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
exp_logits = torch.exp(logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=-1)
|
||||
dist.all_reduce(sum_exp_logits,
|
||||
group=gpc.get_group(output_parallel_mode))
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, target_mask, masked_target)
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grad):
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target = ctx.saved_tensors
|
||||
|
||||
# All the inputs have softmax as their gradient.
|
||||
input_grad = softmax
|
||||
# For simplicity, work with the 2D gradient.
|
||||
partition_vocab_size = softmax.size()[-1]
|
||||
grad_2d = input_grad.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0,
|
||||
end=grad_2d.size()[0],
|
||||
device=get_current_device())
|
||||
grad_2d[arange_1d,
|
||||
masked_target] -= (1.0 - target_mask.view(-1).float())
|
||||
input_grad.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
||||
return input_grad, None, None, None
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
class CrossEntropyLoss3D(_Loss):
|
||||
"""Cross entropy loss for 3D parallelism
|
||||
|
||||
:param depth: depth for 3D parallelism
|
||||
:type depth: int
|
||||
:param input_parallel_mode: parallel mode for input tensor
|
||||
:type input_parallel_mode: ParallelMode
|
||||
:param weight_parallel_mode: parallel mode for weight
|
||||
:type weight_parallel_mode: ParallelMode
|
||||
:param reduction: whether to average the loss, defaults to True
|
||||
:type reduction: bool, optional
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
# input_parallel_mode,
|
||||
# weight_parallel_mode,
|
||||
reduction=True,
|
||||
label_smoothing=0.0):
|
||||
super().__init__()
|
||||
self.depth = get_depth_from_env()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
self.input_rank = gpc.get_local_rank(self.input_parallel_mode)
|
||||
self.weight_rank = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
self.reduction_mean = reduction
|
||||
|
||||
def forward(self, logits, targets):
|
||||
# split label partition from the entire batch
|
||||
batch_size = targets.size(0)
|
||||
targets = torch.chunk(targets, self.depth, dim=0)[self.weight_rank]
|
||||
targets = torch.chunk(targets, self.depth, dim=0)[self.input_rank]
|
||||
loss = _ParallelCrossEntropyLossFunction_3D.apply(
|
||||
logits, targets, self.depth, self.output_parallel_mode)
|
||||
if self.reduction_mean:
|
||||
loss = loss.sum()
|
||||
loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
|
||||
loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
|
||||
loss /= batch_size
|
||||
return loss
|
||||
|
||||
|
||||
# @LOSSES.register_module
|
||||
# class LabelSmoothingCrossEntropy3D(_Loss):
|
||||
# """
|
||||
# NLL loss with label smoothing, adapted from timm.loss.LabelSmoothingCrossEntropy
|
||||
|
||||
# :param input_parallel_mode: parallel mode for input tensor
|
||||
# :type input_parallel_mode: ParallelMode
|
||||
# :param weight_parallel_mode: parallel mode for weight
|
||||
# :type weight_parallel_mode: ParallelMode
|
||||
# :param smoothing: label smoothing value, defaults to 0.1
|
||||
# :type smoothing: float
|
||||
# :param reduction: whether to average the loss, defaults to True
|
||||
# :type reduction: bool, optional
|
||||
# """
|
||||
# def __init__(self,
|
||||
# input_parallel_mode,
|
||||
# weight_parallel_mode,
|
||||
# smoothing=0.1,
|
||||
# reduction=True):
|
||||
# super().__init__()
|
||||
# assert smoothing < 1.0
|
||||
# self.smoothing = smoothing
|
||||
# self.confidence = 1. - smoothing
|
||||
# self.depth = get_depth_from_env()
|
||||
# self.input_parallel_mode = input_parallel_mode
|
||||
# self.weight_parallel_mode = weight_parallel_mode
|
||||
# self.output_parallel_mode = get_last_group(input_parallel_mode,
|
||||
# weight_parallel_mode)
|
||||
# self.reduction_mean = reduction
|
||||
|
||||
# def forward(self, logits, targets):
|
||||
# # split label partition from the entire batch
|
||||
# j = gpc.get_local_rank(self.input_parallel_mode)
|
||||
# i = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
# targets = torch.chunk(targets, self.depth, dim=0)[i]
|
||||
# targets = torch.chunk(targets, self.depth, dim=0)[j]
|
||||
# exp_logits = torch.exp(logits)
|
||||
# sum_exp_logits = Sum3D.apply(exp_logits, -1, depth,
|
||||
# self.output_parallel_mode, False)
|
||||
# log_probs = torch.log(sum_exp_logits) - logits
|
||||
# nll_loss = _ParallelCrossEntropyLossFunction_3D.apply(
|
||||
# logits, targets, self.depth, self.output_parallel_mode)
|
||||
# smooth_loss = -log_probs.mean(dim=-1)
|
||||
# loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
||||
# if self.reduction_mean:
|
||||
# loss = loss.sum()
|
||||
# loss = Reduce_3D.apply(loss, self.depth, self.input_parallel_mode)
|
||||
# loss = Reduce_3D.apply(loss, self.depth, self.weight_parallel_mode)
|
||||
# loss /= batch_size
|
||||
# return loss
|
30
colossalai/nn/loss/loss_2d.py
Normal file
@ -0,0 +1,30 @@
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss


@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
    """Cross entropy loss for 2D parallelism

    :param reduction: whether to average the loss, defaults to True
    :type reduction: bool, optional
    """
    def __init__(self, reduction=True, *args, **kwargs):
        super().__init__()
        assert_summa_initialization()
        self.reduction_mean = reduction
        self.loss_args = args
        self.loss_kwargs = kwargs

    def forward(self, logits, targets):
        batch_size = targets.size(0)
        targets = split_batch_2d(targets)
        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
        if self.reduction_mean:
            loss = loss.sum()
            loss = reduce_by_batch_2d.apply(loss)
            loss /= batch_size
        return loss
29
colossalai/nn/loss/loss_2p5d.py
Normal file
@ -0,0 +1,29 @@
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss


@LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss):
    """Cross entropy loss for 2.5D parallelism

    :param reduction: whether to average the loss, defaults to True
    :type reduction: bool, optional
    """
    def __init__(self, reduction=True, *args, **kwargs):
        super().__init__()
        assert_tesseract_initialization()
        self.reduction_mean = reduction
        self.loss_args = args
        self.loss_kwargs = kwargs

    def forward(self, logits, targets):
        batch_size = targets.size(0)
        targets = split_batch_2p5d(targets)
        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
        if self.reduction_mean:
            loss = loss.sum()
            loss = reduce_by_batch_2p5d.apply(loss)
            loss /= batch_size
        return loss
38
colossalai/nn/loss/loss_3d.py
Normal file
@ -0,0 +1,38 @@
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.registry import LOSSES
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss


@LOSSES.register_module
class CrossEntropyLoss3D(_Loss):
    """Cross entropy loss for 3D parallelism

    :param reduction: whether to average the loss, defaults to True
    :type reduction: bool, optional
    """
    def __init__(self, reduction=True, *args, **kwargs):
        super().__init__()
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.reduction_mean = reduction
        self.loss_args = args
        self.loss_kwargs = kwargs

    def forward(self, logits, targets):
        batch_size = targets.size(0)
        targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
        loss = cross_entropy(logits, targets, reduction='sum', *self.loss_args, **self.loss_kwargs)
        if self.reduction_mean:
            loss = loss.sum()
            loss = reduce_by_batch_3d.apply(loss, self.input_parallel_mode, self.weight_parallel_mode)
            loss /= batch_size
        return loss
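All three losses follow the same reduction pattern: sum the per-sample losses on each rank, all-reduce across the groups that split the batch, then divide by the pre-split batch size. A plain `torch.distributed` sketch of that step (the group handle is a placeholder):

```
import torch
import torch.distributed as dist


def mean_over_split_batch(local_loss_sum: torch.Tensor, global_batch_size: int, group=None) -> torch.Tensor:
    # Each rank holds the summed loss of its slice of the batch; summing across
    # the group and dividing recovers the mean over the whole batch.
    dist.all_reduce(local_loss_sum, group=group)
    return local_loss_sum / global_batch_size
```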
24
colossalai/nn/metric/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
from torch import nn
|
||||
|
||||
from ._utils import calc_acc
|
||||
from .accuracy_2d import Accuracy2D
|
||||
from .accuracy_2p5d import Accuracy2p5D
|
||||
from .accuracy_3d import Accuracy3D
|
||||
|
||||
_parallel_accuracy = {
|
||||
'2d': Accuracy2D,
|
||||
'2.5d': Accuracy2p5D,
|
||||
'3d': Accuracy3D,
|
||||
}
|
||||
|
||||
|
||||
class Accuracy(nn.Module):
|
||||
def __init__(self, tensor_parallel: str = None):
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.acc = calc_acc
|
||||
else:
|
||||
self.acc = _parallel_accuracy[tensor_parallel]()
|
||||
|
||||
def forward(self, *args):
|
||||
return self.acc(*args)
|
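`Accuracy` is a thin dispatcher: with `tensor_parallel` left at `None` (or `'1d'`) it falls back to the plain `calc_acc` helper, while `'2d'`, `'2.5d'` and `'3d'` instantiate the matching parallel collector. A small sketch of the fallback path, which runs without any parallel context (the tensors here are made up for illustration):

```python
import torch
from colossalai.nn.metric import Accuracy

# None (or '1d') means no tensor splitting, so the wrapper is just calc_acc;
# '2d' / '2.5d' / '3d' would require the corresponding process groups instead.
metric = Accuracy(tensor_parallel=None)

logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, -1.0]])
targets = torch.tensor([0, 1, 1])
print(metric(logits, targets).item())    # 2 -> two of three predictions are correct
```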
6
colossalai/nn/metric/_utils.py
Normal file
@ -0,0 +1,6 @@
import torch

def calc_acc(logits, targets):
    preds = torch.argmax(logits, dim=-1)
    correct = torch.sum(targets == preds)
    return correct
17
colossalai/nn/metric/accuracy_2d.py
Normal file
@ -0,0 +1,17 @@
import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from torch import nn

from ._utils import calc_acc


class Accuracy2D(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, targets):
        with torch.no_grad():
            targets = split_batch_2d(targets)
            correct = calc_acc(logits, targets)
            correct = reduce_by_batch_2d.apply(correct)
        return correct
17
colossalai/nn/metric/accuracy_2p5d.py
Normal file
@ -0,0 +1,17 @@
import torch
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from torch import nn

from ._utils import calc_acc


class Accuracy2p5D(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, logits, targets):
        with torch.no_grad():
            targets = split_batch_2p5d(targets)
            correct = calc_acc(logits, targets)
            correct = reduce_by_batch_2p5d.apply(correct)
        return correct
21
colossalai/nn/metric/accuracy_3d.py
Normal file
@ -0,0 +1,21 @@
import torch
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_batch_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from torch import nn

from ._utils import calc_acc


class Accuracy3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)

    def forward(self, logits, targets):
        with torch.no_grad():
            targets = split_batch_3d(targets, self.input_parallel_mode, self.weight_parallel_mode)
            correct = calc_acc(logits, targets)
            correct = reduce_by_batch_3d.apply(correct, self.input_parallel_mode, self.weight_parallel_mode)
        return correct
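All three collectors share the same pattern as the parallel losses: split the label batch the same way the model split its inputs, count correct predictions locally with `calc_acc`, then sum the counts across the batch group with `reduce_by_batch_*`. The division by the global sample count happens later in `AccuracyMetric` (see the hook changes below). A single-process sketch of the count-then-sum arithmetic, with hypothetical shard boundaries:

```python
import torch


def calc_acc(logits, targets):
    # same helper as colossalai/nn/metric/_utils.py above
    preds = torch.argmax(logits, dim=-1)
    return torch.sum(targets == preds)


torch.manual_seed(0)
logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))

# Two hypothetical batch shards: count correct per shard, then sum the counts
# (the reduce_by_batch step); accuracy is the summed count over the full batch.
local_counts = [calc_acc(logits[0:4], targets[0:4]), calc_acc(logits[4:8], targets[4:8])]
global_correct = sum(local_counts)

assert global_correct == calc_acc(logits, targets)
accuracy = global_correct.item() / targets.size(0)
```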
@ -164,31 +164,35 @@ class Trainer:
|
||||
if epoch is None:
|
||||
progress = tqdm(progress, desc='[Train]')
|
||||
else:
|
||||
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
|
||||
progress = tqdm(progress, desc=f'[Epoch {epoch} / Train]')
|
||||
|
||||
self._call_hooks('before_train_epoch')
|
||||
self._call_timer(action='start', item='train-epoch')
|
||||
self._call_timer(action='start', item='Train-epoch')
|
||||
for i in progress:
|
||||
self._call_hooks('before_train_iter')
|
||||
self._call_timer(action='start', item='train-step')
|
||||
self._call_timer(action='start', item='Train-step')
|
||||
|
||||
# run 1 training step
|
||||
self.engine.zero_grad()
|
||||
logits, label, loss = self.schedule.forward_backward_step(
|
||||
self.engine, data_iter, forward_only=False, return_loss=True)
|
||||
self.engine.step()
|
||||
self._call_timer(action='stop', item='train-step', keep_in_history=True)
|
||||
self._call_timer(action='stop', item='Train-step', keep_in_history=True)
|
||||
self._call_hooks('after_train_iter', output=(logits, label, loss))
|
||||
|
||||
self._cur_step += 1
|
||||
|
||||
if display_progress:
|
||||
if 'step_metrics' in self.states:
|
||||
progress.set_postfix(**self.states['step_metrics'])
|
||||
|
||||
# stop when max iter is reached
|
||||
if self._exceed_max_step():
|
||||
break
|
||||
|
||||
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
|
||||
self._call_timer(action='stop', item='Train-epoch', keep_in_history=True)
|
||||
self._call_hooks('after_train_epoch')
|
||||
self._call_timer(action='reset', item='train-step')
|
||||
self._call_timer(action='reset', item='Train-step')
|
||||
|
||||
def _eval(self,
|
||||
test_dataloader: DataLoader,
|
||||
@ -206,25 +210,30 @@ class Trainer:
|
||||
if display_progress:
|
||||
desc = 'Evaluation'
|
||||
if epoch is not None:
|
||||
desc = '[Epoch %d val]' % epoch
|
||||
desc = '[Epoch %d / Test]' % epoch
|
||||
progress = tqdm(progress, desc=desc)
|
||||
|
||||
self._call_hooks('before_test_epoch')
|
||||
self._call_timer(action='start', item='test-epoch')
|
||||
self._call_timer(action='start', item='Test-epoch')
|
||||
with torch.no_grad():
|
||||
for _ in progress:
|
||||
self._call_hooks('before_test_iter')
|
||||
self._call_timer(action='start', item='test-step')
|
||||
self._call_timer(action='start', item='Test-step')
|
||||
logits, label, loss = self.schedule.forward_backward_step(
|
||||
self.engine, data_iter, forward_only=True, return_loss=True)
|
||||
self._call_timer(action='stop', item='test-step', keep_in_history=True)
|
||||
self._call_timer(action='stop', item='Test-step', keep_in_history=True)
|
||||
self._call_hooks('after_test_iter',
|
||||
output=(logits, label, loss))
|
||||
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
|
||||
|
||||
if display_progress:
|
||||
if 'step_metrics' in self.states:
|
||||
progress.set_postfix(**self.states['step_metrics'])
|
||||
|
||||
self._call_timer(action='stop', item='Test-epoch', keep_in_history=True)
|
||||
self._call_hooks('after_test_epoch')
|
||||
self._call_hooks('after_test')
|
||||
self._call_timer(action='reset', item='test-step')
|
||||
self._call_timer(action='reset', item='test-epoch')
|
||||
self._call_timer(action='reset', item='Test-step')
|
||||
self._call_timer(action='reset', item='Test-epoch')
|
||||
|
||||
def _exceed_max_step(self):
|
||||
return self._max_steps is not None and self._cur_step >= self._max_steps
|
||||
@ -317,7 +326,7 @@ class Trainer:
|
||||
ranks=[0])
|
||||
break
|
||||
self._call_hooks('after_train')
|
||||
self._call_timer('reset', 'train-epoch')
|
||||
self._call_timer('reset', 'Train-epoch')
|
||||
|
||||
def evaluate(self,
|
||||
test_dataloader: DataLoader,
|
||||
@ -374,4 +383,4 @@ class Trainer:
|
||||
data_iter = iter(simple_dataloader)
|
||||
output, _, _ = self.schedule.forward_backward_step(
|
||||
self.engine, data_iter, forward_only=True, return_loss=False)
|
||||
return output
|
||||
return output
|
@ -1,15 +1,12 @@
|
||||
from ._base_hook import BaseHook
|
||||
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
|
||||
from ._metric_hook import (LossHook, Accuracy2DHook, AccuracyHook, MetricHook,
|
||||
Accuracy1DHook, Accuracy2p5DHook, Accuracy3DHook)
|
||||
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
|
||||
from ._checkpoint_hook import LoadCheckpointHook, SaveCheckpointHook
|
||||
from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
|
||||
TensorboardHook)
|
||||
from ._lr_scheduler_hook import LRSchedulerHook
|
||||
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
|
||||
|
||||
__all__ = [
|
||||
'BaseHook', 'MetricHook',
|
||||
'LoadCheckpointHook', 'SaveCheckpointHook',
|
||||
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
|
||||
'Accuracy1DHook', 'Accuracy2p5DHook', 'Accuracy3DHook',
|
||||
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
|
||||
'LRSchedulerHook'
|
||||
'BaseHook', 'MetricHook', 'LoadCheckpointHook', 'SaveCheckpointHook', 'LossHook', 'AccuracyHook',
|
||||
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook',
|
||||
'ThroughputHook', 'LogMetricByStepHook'
|
||||
]
|
||||
|
@ -16,11 +16,11 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
|
||||
from ._base_hook import BaseHook
|
||||
|
||||
|
||||
def _format_number(val):
|
||||
def _format_number(val, prec=5):
|
||||
if isinstance(val, float):
|
||||
return f'{val:.5g}'
|
||||
return f'{val:.{prec}g}'
|
||||
elif torch.is_tensor(val) and torch.is_floating_point(val):
|
||||
return f'{val.item():.5g}'
|
||||
return f'{val.item():.{prec}g}'
|
||||
return val
|
||||
|
||||
|
||||
@ -37,6 +37,24 @@ class LogByEpochHook(BaseHook):
|
||||
return trainer.cur_epoch % self._interval == 0
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByStepHook(BaseHook):
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
|
||||
trainer.states['step_metrics'][metric_name.lower()] = \
|
||||
f'{_format_number(metric_calculator.get_last_step_value())}'
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
|
||||
trainer.states['step_metrics'][metric_name.lower()] = \
|
||||
f'{_format_number(metric_calculator.get_last_step_value())}'
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByEpochHook(LogByEpochHook):
|
||||
"""Specialized Hook to record the metric to log.
|
||||
@ -61,7 +79,7 @@ class LogMetricByEpochHook(LogByEpochHook):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
msg.append(
|
||||
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
|
||||
msg = ', '.join(msg)
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
@ -69,15 +87,15 @@ class LogMetricByEpochHook(LogByEpochHook):
|
||||
msg = self._get_str(trainer=trainer, mode='train')
|
||||
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(
|
||||
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}')
|
||||
# f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
if self._is_epoch_to_log(trainer):
|
||||
msg = self._get_str(trainer=trainer, mode='test')
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(
|
||||
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
|
||||
# f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@ -131,8 +149,7 @@ class TensorboardHook(BaseHook):
|
||||
log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
self.writer = SummaryWriter(
|
||||
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||
|
||||
def _log_by_iter(self, trainer, mode: str):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
@ -141,16 +158,14 @@ class TensorboardHook(BaseHook):
|
||||
val = metric_calculator.get_last_step_value()
|
||||
|
||||
if self._is_valid_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val,
|
||||
trainer.cur_step)
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
|
||||
|
||||
def _log_by_epoch(self, trainer, mode: str):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
if metric_calculator.epoch_only:
|
||||
val = metric_calculator.get_accumulated_value()
|
||||
if self._is_valid_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val,
|
||||
trainer.cur_step)
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
self._log_by_iter(trainer, mode='test')
|
||||
@ -178,15 +193,13 @@ class LogTimingByEpochHook(LogByEpochHook):
|
||||
:param log_eval: Whether writes in evaluation
|
||||
:type log_eval: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
timer: MultiTimer,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
ignore_num_train_steps: int = 0
|
||||
) -> None:
|
||||
ignore_num_train_steps: int = 0) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._timer = timer
|
||||
self._log_eval = log_eval
|
||||
@ -197,40 +210,39 @@ class LogTimingByEpochHook(LogByEpochHook):
|
||||
self._ignore_num_train_steps = ignore_num_train_steps
|
||||
self._is_train_step_history_trimmed = False
|
||||
|
||||
def _get_message(self):
|
||||
def _get_message(self, mode):
|
||||
msg = []
|
||||
for timer_name, timer in self._timer:
|
||||
last_elapsed_time = timer.get_elapsed_time()
|
||||
if timer.has_history:
|
||||
if timer_name == 'train-step' and not self._is_train_step_history_trimmed:
|
||||
timer._history = timer._history[self._ignore_num_train_steps:]
|
||||
self._is_train_step_history_trimmed = True
|
||||
history_mean = timer.get_history_mean()
|
||||
history_sum = timer.get_history_sum()
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s')
|
||||
else:
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
|
||||
if timer_name.startswith(mode):
|
||||
last_elapsed_time = timer.get_elapsed_time()
|
||||
if timer.has_history:
|
||||
if timer_name == 'Train-step' and not self._is_train_step_history_trimmed:
|
||||
timer._history = timer._history[self._ignore_num_train_steps:]
|
||||
self._is_train_step_history_trimmed = True
|
||||
history_mean = timer.get_history_mean()
|
||||
history_sum = timer.get_history_sum()
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s'
|
||||
)
|
||||
else:
|
||||
msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
|
||||
|
||||
msg = ', '.join(msg)
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
msg = self._get_message()
|
||||
self.logger.info(
|
||||
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={trainer.steps_per_epoch}')
|
||||
msg = self._get_message('Train')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}, #steps/epoch = {trainer.steps_per_epoch}')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
"""Writes log after finishing a testing epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
|
||||
msg = self._get_message()
|
||||
self.logger.info(
|
||||
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
msg = self._get_message('Test')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@ -246,14 +258,12 @@ class LogMemoryByEpochHook(LogByEpochHook):
|
||||
:param log_eval: Whether writes in evaluation
|
||||
:type log_eval: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
report_cpu: bool = False
|
||||
) -> None:
|
||||
report_cpu: bool = False) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
|
||||
@ -262,20 +272,16 @@ class LogMemoryByEpochHook(LogByEpochHook):
|
||||
"""Resets before training.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
report_memory_usage('before-train', self.logger)
|
||||
report_memory_usage('Before-train', self.logger)
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
report_memory_usage(
|
||||
f'After Train - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
|
||||
self.logger)
|
||||
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger)
|
||||
|
||||
def after_test(self, trainer):
|
||||
"""Reports after testing.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
|
||||
report_memory_usage(
|
||||
f'After Test - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
|
||||
self.logger)
|
||||
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger)
|
||||
|
@ -1,9 +1,7 @@
|
||||
from colossalai.registry import HOOKS
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.builder import build_lr_scheduler
|
||||
from colossalai.registry import HOOKS
|
||||
from ._metric_hook import MetricHook
|
||||
from ..metric import LearningRate
|
||||
from ._metric_hook import LearningRateMetric, MetricHook
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@ -19,28 +17,28 @@ class LRSchedulerHook(MetricHook):
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lr_scheduler,
|
||||
by_epoch: bool,
|
||||
store_lr_in_state: bool = True,
|
||||
priority: int = 1,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
lr_scheduler,
|
||||
by_epoch: bool,
|
||||
store_lr_in_state: bool = True,
|
||||
priority: int = 1,
|
||||
):
|
||||
super().__init__(priority=priority)
|
||||
self.by_epoch = by_epoch
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.store_lr_in_state = store_lr_in_state
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=self.by_epoch,
|
||||
initial_lr=self.lr_scheduler.get_last_lr()[0])
|
||||
trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
|
||||
initial_lr=self.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
if self.by_epoch:
|
||||
self.lr_scheduler.step()
|
||||
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
||||
if not self.by_epoch:
|
||||
self.lr_scheduler.step()
|
||||
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
|
@ -1,11 +1,209 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication import all_reduce
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ..metric import Loss, Accuracy1D, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
"""A basic class of metric collectors. It collects a specific
|
||||
metric during training or evaluation and it's always used with
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
collector works.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
# is the metric only read for the full epoch
|
||||
self._epoch_only = epoch_only
|
||||
|
||||
@property
|
||||
def epoch_only(self):
|
||||
"""Returns :attr:`epoch_only`.
|
||||
"""
|
||||
return self._epoch_only
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Resets the metric to it's initial state.
|
||||
By default, this is called at the start of each epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""Updates the metric's state using the passed batch output.
|
||||
By default, this is called once for each batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_last_step_value(self):
|
||||
"""Returns the metric value in the last iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_accumulated_value(self):
|
||||
"""Computes the metric based on it's accumulated state.
|
||||
By default, this is called at the end of each epoch.
|
||||
|
||||
:return: the actual quantity of interest
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def is_better(a, b) -> bool:
|
||||
"""Compares a and b, and returns whether a is better than b
|
||||
|
||||
:return: The result of comparison
|
||||
:rtype: bool
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LossMetric(Metric):
|
||||
"""A metric collector for loss.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
|
||||
"""
|
||||
self.last_step_loss.zero_()
|
||||
self.accum_loss.zero_()
|
||||
self.count = 0
|
||||
|
||||
def update(self, loss) -> None:
|
||||
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
|
||||
It expects the output has loss.
|
||||
|
||||
:param loss: Current loss of the output
|
||||
"""
|
||||
# expect output to be logits, label and loss
|
||||
loss_ = loss.detach()
|
||||
self.last_step_loss.copy_(loss_)
|
||||
self.accum_loss.add_(loss_)
|
||||
self.count += 1
|
||||
|
||||
def get_accumulated_value(self):
|
||||
"""Returns accumulated loss.
|
||||
"""
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
|
||||
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
|
||||
|
||||
self.accum_loss.div_(self.count)
|
||||
return self.accum_loss.item()
|
||||
|
||||
def get_last_step_value(self):
|
||||
"""Returns :attr:`last_step_loss`.
|
||||
"""
|
||||
return self.last_step_loss
|
||||
|
||||
def is_better(a, b):
|
||||
return a < b
|
||||
|
||||
|
||||
class LearningRateMetric(Metric):
|
||||
"""A metric collector for learning rate.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.lr = initial_lr
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def update(self, lr) -> None:
|
||||
self.lr = lr
|
||||
|
||||
def get_last_step_value(self):
|
||||
return self.lr
|
||||
|
||||
def get_accumulated_value(self):
|
||||
return self.lr
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class AccuracyMetric(Metric):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.acc = accuracy_func
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
self.last_step_correct.zero_()
|
||||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, targets) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
:param logits: The logits output of the model
|
||||
:param label: The labels of the input data
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(targets, (list, tuple)):
|
||||
targets = targets[0]
|
||||
# update
|
||||
correct = self.acc(logits, targets)
|
||||
|
||||
self.last_step_sum.fill_(targets.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def get_last_step_value(self):
|
||||
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
|
||||
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
|
||||
return (self.last_step_correct / self.last_step_sum).item()
|
||||
|
||||
def get_accumulated_value(self):
|
||||
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
|
||||
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
|
||||
return (self.accumulated_correct / self.accumulated_sum).item()
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
|
||||
class MetricHook(BaseHook):
|
||||
@ -19,10 +217,10 @@ class MetricHook(BaseHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
priority: int,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
priority: int,
|
||||
):
|
||||
super().__init__(priority)
|
||||
self._is_stage_to_compute = is_no_pp_or_last_stage()
|
||||
|
||||
@ -40,7 +238,6 @@ class LossHook(MetricHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
@ -48,14 +245,12 @@ class LossHook(MetricHook):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
|
||||
if self._is_stage_to_compute:
|
||||
self.train_loss = Loss(epoch_only=False)
|
||||
self.test_loss = Loss(epoch_only=True)
|
||||
self.train_loss = LossMetric(epoch_only=False)
|
||||
self.test_loss = LossMetric(epoch_only=True)
|
||||
|
||||
# register the metric calculator
|
||||
trainer.states['metrics']['train'][
|
||||
self.train_loss.__class__.__name__] = self.train_loss
|
||||
trainer.states['metrics']['test'][
|
||||
self.test_loss.__class__.__name__] = self.test_loss
|
||||
trainer.states['metrics']['train']['Loss'] = self.train_loss
|
||||
trainer.states['metrics']['test']['Loss'] = self.test_loss
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
@ -74,124 +269,6 @@ class LossHook(MetricHook):
|
||||
self.test_loss.update(loss)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy1DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy1D`.
|
||||
It acts the same as :class:`AccuracyHook`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy1D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy2DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy2D`.
|
||||
It acts the same as :class:`AccuracyHook`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy2D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy2p5DHook(MetricHook):
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy2p5D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy3DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy3D`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy3D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class AccuracyHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy`.
|
||||
@ -201,22 +278,87 @@ class AccuracyHook(MetricHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
def __init__(self, accuracy_func: Callable, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
self.accuracy_func = accuracy_func
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy(epoch_only=True)
|
||||
self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
trainer.states['metrics']['test']['Accuracy'] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
self.metric.update(logits, targets)
|
||||
|
||||
|
||||
class ThroughputMetric(Metric):
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.accumulated_num_samples.zero_()
|
||||
self.accumulated_used_time.zero_()
|
||||
self.last_step_num_samples.zero_()
|
||||
self.last_step_used_time.zero_()
|
||||
|
||||
def update(self, tensor, time) -> None:
|
||||
if isinstance(tensor, (list, tuple)):
|
||||
tensor = tensor[0]
|
||||
self.last_step_num_samples.fill_(tensor.size(0))
|
||||
self.last_step_used_time.fill_(time)
|
||||
self.accumulated_num_samples += self.last_step_num_samples
|
||||
self.accumulated_used_time += self.last_step_used_time
|
||||
|
||||
def get_last_step_value(self):
|
||||
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
|
||||
return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
|
||||
|
||||
def get_accumulated_value(self):
|
||||
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
|
||||
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class ThroughputHook(MetricHook):
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = ThroughputMetric(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['train']['Throughput'] = self.metric
|
||||
trainer.states['metrics']['test']['Throughput'] = self.metric
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
@ -1,356 +0,0 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication import all_gather
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D,
|
||||
WEIGHT_GROUP_3D)
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer._parallel_utilities import _gather
|
||||
from colossalai.nn.layer.parallel_3d._utils import (get_last_group,
|
||||
get_parallel_mode_from_env)
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
"""A basic class of metric collectors. It collects a specific
|
||||
metric during training or evaluation and it's always used with
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
collector works.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
# is the metric only read for the full epoch
|
||||
self._epoch_only = epoch_only
|
||||
|
||||
@property
|
||||
def epoch_only(self):
|
||||
"""Returns :attr:`epoch_only`.
|
||||
"""
|
||||
return self._epoch_only
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Resets the metric to it's initial state.
|
||||
By default, this is called at the start of each epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""Updates the metric's state using the passed batch output.
|
||||
By default, this is called once for each batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_last_step_value(self):
|
||||
"""Returns the metric value in the last iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_accumulated_value(self):
|
||||
"""Computes the metric based on it's accumulated state.
|
||||
By default, this is called at the end of each epoch.
|
||||
|
||||
:return: the actual quantity of interest
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def is_better(a, b) -> bool:
|
||||
"""Compares a and b, and returns whether a is better than b
|
||||
|
||||
:return: The result of comparison
|
||||
:rtype: bool
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Loss(Metric):
|
||||
"""A metric collector for loss.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
|
||||
"""
|
||||
self.last_step_loss.zero_()
|
||||
self.accum_loss.zero_()
|
||||
self.count = 0
|
||||
|
||||
def update(self, loss) -> None:
|
||||
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
|
||||
It expects the output has loss.
|
||||
|
||||
:param loss: Current loss of the output
|
||||
"""
|
||||
# expect output to be logits, label and loss
|
||||
loss_ = loss.detach()
|
||||
self.last_step_loss.copy_(loss_)
|
||||
self.accum_loss.add_(loss_)
|
||||
self.count += 1
|
||||
|
||||
def get_accumulated_value(self):
|
||||
"""Returns accumulated loss.
|
||||
"""
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dist.all_reduce(self.accum_loss,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
|
||||
|
||||
self.accum_loss.div_(self.count)
|
||||
return self.accum_loss.item()
|
||||
|
||||
def get_last_step_value(self):
|
||||
"""Returns :attr:`last_step_loss`.
|
||||
"""
|
||||
return self.last_step_loss
|
||||
|
||||
def is_better(a, b):
|
||||
return a < b
|
||||
|
||||
|
||||
class LearningRate(Metric):
|
||||
"""A metric collector for learning rate.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.lr = 0.
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def update(self, lr) -> None:
|
||||
self.lr = lr
|
||||
|
||||
def get_last_step_value(self):
|
||||
return self.lr
|
||||
|
||||
def get_accumulated_value(self):
|
||||
return self.lr
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
self.last_step_correct.zero_()
|
||||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
:param logits: The logits output of the model
|
||||
:param label: The labels of the input data
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def get_last_step_value(self):
|
||||
dist.all_reduce(self.last_step_sum,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
dist.all_reduce(self.last_step_correct,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
return (self.last_step_sum / self.last_step_correct).item()
|
||||
|
||||
def get_accumulated_value(self):
|
||||
dist.all_reduce(self.accumulated_sum,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
dist.all_reduce(self.accumulated_correct,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
return (self.accumulated_correct / self.accumulated_sum).item()
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
class Accuracy2D(Accuracy):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks. This class is the same as :class:`Accuracy` but used in 2D
|
||||
model parallelism.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
logits = _gather(logits, ParallelMode.PARALLEL_2D_ROW, 1)
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
0,
|
||||
)
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
class Accuracy1D(Accuracy):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks. This class is the same as :class:`Accuracy` but used in 2D
|
||||
model parallelism.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_1D,
|
||||
1
|
||||
)
|
||||
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
|
||||
class Accuracy2p5D(Accuracy):
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
logits = _gather(logits, ParallelMode.PARALLEL_2P5D_ROW, 1)
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
0,
|
||||
)
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
0,
|
||||
)
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
|
||||
class Accuracy3D(Accuracy):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks. This class is the same as :class:`Accuracy` but used in 3D
|
||||
model parallelism.
|
||||
|
||||
:param input_parallel_mode: The parallel mode of the input, generally it should be `ParallelMode.PARALLEL_3D_OUTPUT`
|
||||
:type input_parallel_mode: `ParallelMode`
|
||||
:param weight_parallel_mode: The parallel mode of the weight, generally it should be `ParallelMode.PARALLEL_3D_WEIGHT`
|
||||
:type weight_parallel_mode: `ParallelMode`
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only):
|
||||
# input_parallel_mode, weight_parallel_mode):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.depth = int(os.environ['DEPTH_3D'])
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
|
||||
self.weight_parallel_mode)
|
||||
|
||||
def update(self, logits, target):
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(target, (list, tuple)):
|
||||
target = target[0]
|
||||
|
||||
batch_size = target.size(0)
|
||||
|
||||
j = gpc.get_local_rank(self.input_parallel_mode)
|
||||
i = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
target = torch.chunk(target, self.depth, dim=0)[i]
|
||||
target = torch.chunk(target, self.depth, dim=0)[j]
|
||||
|
||||
logits = all_gather(logits, -1, self.output_parallel_mode)
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
prediction = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(prediction == target)
|
||||
|
||||
dist.all_reduce(correct, group=gpc.get_group(self.input_parallel_mode))
|
||||
dist.all_reduce(correct,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
|
||||
self.last_step_sum.fill_(batch_size)
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
@ -48,14 +48,14 @@ def report_memory_usage(message, logger=None, report_cpu=False):
|
||||
gpu_cached = bytes_to_MB(torch.cuda.memory_reserved())
|
||||
gpu_max_cached = bytes_to_MB(torch.cuda.max_memory_reserved())
|
||||
|
||||
full_log = f"{message} - GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, \
|
||||
cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
|
||||
full_log = f"{message}: GPU: allocated {gpu_allocated} MB, max allocated {gpu_max_allocated} MB, " \
|
||||
+ f"cached: {gpu_cached} MB, max cached: {gpu_max_cached} MB"
|
||||
|
||||
if report_cpu:
|
||||
# python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
|
||||
gc.collect()
|
||||
vm_stats=psutil.virtual_memory()
|
||||
vm_used=bytes_to_MB(vm_stats.total - vm_stats.available)
|
||||
vm_stats = psutil.virtual_memory()
|
||||
vm_used = bytes_to_MB(vm_stats.total - vm_stats.available)
|
||||
full_log += f", CPU Virtual Memory: used = {vm_used} MB, percent = {vm_stats.percent}%"
|
||||
|
||||
if logger is None:
|
||||
|
@ -13,9 +13,7 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py',
|
||||
host='localhost',
|
||||
port=29500)
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
@ -1,22 +1,22 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
import colossalai
|
||||
import torch
|
||||
import os
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_dataloader, MultiTimer
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CosineAnnealingLR
|
||||
from colossalai.nn.metric import Accuracy
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.utils import MultiTimer, get_dataloader
|
||||
from torchvision import transforms
|
||||
from colossalai.trainer import hooks, Trainer
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet34
|
||||
from colossalai.nn import CosineAnnealingLR
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py',
|
||||
host='localhost',
|
||||
port=29500)
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
@ -93,7 +93,7 @@ def main():
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
|
||||
hooks.AccuracyHook(),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LogMemoryByEpochHook(logger),
|
||||
hooks.LogTimingByEpochHook(timer, logger),
|
||||
|
@ -19,5 +19,5 @@ dataset = dict(
|
||||
)
|
||||
|
||||
gradient_accumulation=2
|
||||
gradient_clipping=1.0
|
||||
clip_grad_norm=1.0
|
||||
|
||||
|
@ -20,4 +20,4 @@ dataset = dict(
|
||||
)
|
||||
|
||||
gradient_accumulation=1
|
||||
gradient_clipping=1.0
|
||||
clip_grad_norm=1.0
|
||||
|
@ -1,3 +1,4 @@
|
||||
from colossalai.nn.metric import Accuracy
|
||||
import torch
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
@ -40,9 +41,7 @@ def build_dataset_test():
|
||||
)
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./le_config.py',
|
||||
host='localhost',
|
||||
port=29500)
|
||||
colossalai.launch_from_torch(config='./le_config.py')
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
@ -81,7 +80,7 @@ def main():
|
||||
# build hooks
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.AccuracyHook(),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
|
||||
TotalBatchsizeHook(),
|
||||
|
@ -41,9 +41,7 @@ def build_dataset_test():
|
||||
)
|
||||
|
||||
def main():
|
||||
colossalai.launch_from_torch(config='./config.py',
|
||||
host='localhost',
|
||||
port=29500)
|
||||
colossalai.launch_from_torch(config='./config.py')
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
|
@ -39,11 +39,7 @@ In your training script:
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
colossalai.launch_from_torch(config=args.config,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
backend=args.backend
|
||||
)
|
||||
colossalai.launch_from_torch(config=args.config)
|
||||
```
|
||||
|
||||
In your terminal
|
||||
|
@ -11,7 +11,7 @@ fp16 = dict(
|
||||
)
|
||||
|
||||
gradient_accumulation = 16
|
||||
gradient_clipping = 1.0
|
||||
clip_grad_norm = 1.0
|
||||
|
||||
dali = dict(
|
||||
# root='./dataset/ILSVRC2012_1k',
|
||||
|
@ -2,6 +2,7 @@ import glob
|
||||
from math import log
|
||||
import os
|
||||
import colossalai
|
||||
from colossalai.nn.metric import Accuracy
|
||||
import torch
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
@ -54,11 +55,15 @@ def main():
|
||||
# initialize distributed setting
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# launch from slurm batch job
|
||||
colossalai.launch_from_slurm(config=args.config,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
backend=args.backend
|
||||
)
|
||||
# launch from torch
|
||||
# colossalai.launch_from_torch(config=args.config)
|
||||
|
||||
# get logger
|
||||
logger = get_dist_logger()
|
||||
@ -91,7 +96,7 @@ def main():
|
||||
# build hooks
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.AccuracyHook(),
|
||||
hooks.AccuracyHook(accuracy_func=Accuracy()),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
|
||||
TotalBatchsizeHook(),
|
||||
|
@ -0,0 +1 @@
|
||||
from .vit import *
|
@ -1,208 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.registry import MODELS
|
||||
|
||||
__all__ = [
|
||||
'VisionTransformer3D',
|
||||
'vit_tiny_1d_patch4_32',
|
||||
'vit_tiny_1d_patch16_224',
|
||||
'vit_tiny_1d_patch16_384',
|
||||
'vit_small_1d_patch16_224',
|
||||
'vit_small_1d_patch16_384',
|
||||
'vit_small_1d_patch32_224',
|
||||
'vit_small_1d_patch32_384',
|
||||
'vit_base_1d_patch16_224',
|
||||
'vit_base_1d_patch16_384',
|
||||
'vit_base_1d_patch32_224',
|
||||
'vit_base_1d_patch32_384',
|
||||
'vit_large_1d_patch16_224',
|
||||
'vit_large_1d_patch16_384',
|
||||
'vit_large_1d_patch32_224',
|
||||
'vit_large_1d_patch32_384',
|
||||
]
|
||||
|
||||
|
||||
class ViTBlock1D(nn.Module):
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.attn = col_nn.ViTSelfAttention1D(dim, num_heads, attn_drop, drop)
|
||||
self.drop_path = col_nn.VanillaViTDropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.mlp = col_nn.ViTMLP1D(dim, 1, drop, 'gelu')
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VisionTransformer1D(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
embed_dim: int = 768,
|
||||
hidden_dim: int = 3072,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = col_nn.ViTPatchEmbedding1D(
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_dim,
|
||||
drop_rate,
|
||||
)
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
self.blocks = nn.Sequential(*[
|
||||
ViTBlock1D(embed_dim, num_heads, hidden_dim,
|
||||
drop_rate, attn_drop_rate, dpr[i])
|
||||
for i in range(depth)
|
||||
])
|
||||
|
||||
self.norm = nn.LayerNorm(embed_dim, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
|
||||
self.head = col_nn.ViTHead1D(hidden_dim, num_classes)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_vit_model(**model_kwargs):
|
||||
model = VisionTransformer1D(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_1d_patch4_32(**kwargs):
|
||||
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
|
||||
depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_1d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, hidden_dim=768, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_1d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_1d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_1d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_1d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_1d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_1d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_1d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_1d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_1d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_1d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_1d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
@ -1 +0,0 @@
|
||||
from .vit import *
|
@ -1,219 +0,0 @@
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai import nn as clsl_nn
|
||||
from colossalai.registry import MODELS
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
|
||||
__all__ = [
|
||||
'VisionTransformer2D',
|
||||
'vit_tiny_2d_patch4_32',
|
||||
'vit_tiny_2d_patch16_224',
|
||||
'vit_tiny_2d_patch16_384',
|
||||
'vit_small_2d_patch16_224',
|
||||
'vit_small_2d_patch16_384',
|
||||
'vit_small_2d_patch32_224',
|
||||
'vit_small_2d_patch32_384',
|
||||
'vit_base_2d_patch16_224',
|
||||
'vit_base_2d_patch16_384',
|
||||
'vit_base_2d_patch32_224',
|
||||
'vit_base_2d_patch32_384',
|
||||
'vit_large_2d_patch16_224',
|
||||
'vit_large_2d_patch16_384',
|
||||
'vit_large_2d_patch32_224',
|
||||
'vit_large_2d_patch32_384',
|
||||
]
|
||||
|
||||
|
||||
class ViTBlock2D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: int = 4,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: str = 'gelu'):
|
||||
super().__init__()
|
||||
self.norm1 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
|
||||
self.attn = clsl_nn.ViTSelfAttention2D(dim, num_heads, attn_drop, drop)
|
||||
self.drop_path = clsl_nn.VanillaViTDropPath(drop_path) if drop_path > 0. \
|
||||
else nn.Identity()
|
||||
self.norm2 = clsl_nn.LayerNorm2D(dim, eps=1e-6)
|
||||
self.mlp = clsl_nn.ViTMLP2D(dim, mlp_ratio, act_layer, drop)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.attn(self.norm1(x))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = x + self.drop_path(y)
|
||||
y = self.mlp(self.norm2(x))
|
||||
with seed(ParallelMode.TENSOR):
|
||||
x = x + self.drop_path(y)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VisionTransformer2D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: int = 4,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
act_layer: str = 'gelu'):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = clsl_nn.ViTPatchEmbedding2D(
|
||||
img_size, patch_size, embed_dim, in_chans
|
||||
)
|
||||
|
||||
self.splitter = clsl_nn.ViTInputSplitter2D()
|
||||
|
||||
self.token_fuser = clsl_nn.ViTTokenFuser2D(
|
||||
img_size, patch_size, embed_dim, drop_rate
|
||||
)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
self.blocks = nn.Sequential(*[
|
||||
ViTBlock2D(embed_dim, num_heads, mlp_ratio, drop_rate,
|
||||
attn_drop_rate, dpr[i], act_layer)
|
||||
for i in range(depth)
|
||||
])
|
||||
|
||||
self.norm = clsl_nn.LayerNorm2D(embed_dim, eps=1e-6)
|
||||
self.head = clsl_nn.ViTHead2D(self.num_features, num_classes) if num_classes > 0 \
|
||||
else nn.Identity()
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.splitter(x)
|
||||
x = self.token_fuser(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_vit_model(**model_kwargs):
|
||||
model = VisionTransformer2D(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_2d_patch4_32(**kwargs):
|
||||
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
|
||||
depth=6, num_heads=8, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_2d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_2d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_2d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
@ -1 +0,0 @@
|
||||
from .vit import *
|
@ -1,209 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai import nn as col_nn
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.registry import MODELS
|
||||
|
||||
__all__ = [
|
||||
'VisionTransformer3D',
|
||||
'vit_tiny_3d_patch4_32',
|
||||
'vit_tiny_3d_patch16_224',
|
||||
'vit_tiny_3d_patch16_384',
|
||||
'vit_small_3d_patch16_224',
|
||||
'vit_small_3d_patch16_384',
|
||||
'vit_small_3d_patch32_224',
|
||||
'vit_small_3d_patch32_384',
|
||||
'vit_base_3d_patch16_224',
|
||||
'vit_base_3d_patch16_384',
|
||||
'vit_base_3d_patch32_224',
|
||||
'vit_base_3d_patch32_384',
|
||||
'vit_large_3d_patch16_224',
|
||||
'vit_large_3d_patch16_384',
|
||||
'vit_large_3d_patch32_224',
|
||||
'vit_large_3d_patch32_384',
|
||||
]
|
||||
|
||||
|
||||
class ViTBlock3D(nn.Module):
|
||||
def __init__(self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.):
|
||||
super().__init__()
|
||||
self.norm1 = col_nn.LayerNorm3D(
|
||||
dim, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6)
|
||||
self.attn = col_nn.ViTSelfAttention3D(dim, num_heads, attn_drop, drop)
|
||||
self.drop_path = col_nn.VanillaViTDropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = col_nn.LayerNorm3D(dim, ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT, eps=1e-6)
|
||||
self.mlp = col_nn.ViTMLP3D(hidden_dim, 1, drop, 'gelu')
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
class VisionTransformer3D(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
embed_dim: int = 768,
|
||||
hidden_dim: int = 3072,
|
||||
drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
|
||||
self.patch_embed = col_nn.ViTPatchEmbedding3D(
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_dim,
|
||||
drop_rate,
|
||||
)
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
self.blocks = nn.Sequential(*[
|
||||
ViTBlock3D(embed_dim, num_heads, hidden_dim,
|
||||
drop_rate, attn_drop_rate, dpr[i])
|
||||
for i in range(depth)
|
||||
])
|
||||
|
||||
self.norm = col_nn.LayerNorm3D(embed_dim, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
|
||||
self.head = col_nn.ViTHead3D(hidden_dim, num_classes)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
pass
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.blocks(x)
|
||||
x = self.norm(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _create_vit_model(**model_kwargs):
|
||||
model = VisionTransformer3D(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_3d_patch4_32(**kwargs):
|
||||
model_kwargs = dict(img_size=32, patch_size=4, embed_dim=512,
|
||||
depth=6, num_heads=8, hidden_dim=512, num_classes=10, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=192,
|
||||
depth=12, num_heads=3, hidden_dim=768, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_tiny_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=192, depth=12, num_heads=3, hidden_dim=768, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=384,
|
||||
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=384,
|
||||
depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_small_3d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=384, depth=12, num_heads=6, hidden_dim=1536, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768,
|
||||
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768,
|
||||
depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_base_3d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=768, depth=12, num_heads=12, hidden_dim=3072, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch16_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024,
|
||||
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch16_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=16,
|
||||
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch32_224(**kwargs):
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024,
|
||||
depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
||||
|
||||
|
||||
@MODELS.register_module
|
||||
def vit_large_3d_patch32_384(**kwargs):
|
||||
model_kwargs = dict(img_size=384, patch_size=32,
|
||||
embed_dim=1024, depth=24, num_heads=16, hidden_dim=4096, **kwargs)
|
||||
return _create_vit_model(**model_kwargs)
|
487
model_zoo/vit/vit.py
Normal file
@ -0,0 +1,487 @@
import math
from typing import Callable

import torch
from colossalai import nn as col_nn
from colossalai.context import ParallelMode, seed
from colossalai.registry import LAYERS, MODELS
from colossalai.utils import checkpoint
from torch import dtype, nn

__all__ = [
    'VisionTransformer',
    'vit_lite_depth7_patch4_32',
    'vit_tiny_patch4_32',
    'vit_tiny_patch16_224',
    'vit_tiny_patch16_384',
    'vit_small_patch16_224',
    'vit_small_patch16_384',
    'vit_small_patch32_224',
    'vit_small_patch32_384',
    'vit_base_patch16_224',
    'vit_base_patch16_384',
    'vit_base_patch32_224',
    'vit_base_patch32_384',
    'vit_large_patch16_224',
    'vit_large_patch16_384',
    'vit_large_patch32_224',
    'vit_large_patch32_384',
]

_init_rules = dict(
    torch=dict(
        embed=dict(
            weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
            position_embed_initializer=col_nn.init.zeros_(),
        ),
        transformer=dict(
            weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
        ),
        head=dict(
            weight_initializer=col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            bias_initializer=col_nn.init.xavier_uniform_(a=1, scale=1),
        ),
    ),
    jax=dict(
        embed=dict(
            weight_initializer=col_nn.init.lecun_normal_(),
            bias_initializer=col_nn.init.zeros_(),
            position_embed_initializer=col_nn.init.trunc_normal_(std=.02),
        ),
        transformer=dict(
            weight_initializer=col_nn.init.xavier_uniform_(),
            bias_initializer=col_nn.init.normal_(std=1e-6),
        ),
        head=dict(
            weight_initializer=col_nn.init.zeros_(),
            bias_initializer=col_nn.init.zeros_(),
        ),
    ),
)


@LAYERS.register_module
class ViTEmbedding(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embedding_dim: int,
                 dropout: float,
                 dtype: dtype = None,
                 flatten: bool = True,
                 init_method: str = 'torch',
                 tensor_parallel: str = None):
        super().__init__()
        self.patch_embed = col_nn.PatchEmbedding(img_size,
                                                 patch_size,
                                                 in_chans,
                                                 embedding_dim,
                                                 dtype=dtype,
                                                 flatten=flatten,
                                                 tensor_parallel=tensor_parallel,
                                                 **_init_rules[init_method]['embed'])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.patch_embed(x)
        with seed(ParallelMode.TENSOR):
            x = self.dropout(x)
        return x


@LAYERS.register_module
class ViTSelfAttention(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None,
                 checkpoint: bool = False,
                 init_method: str = 'torch',
                 tensor_parallel: str = None):
        super().__init__()
        self.attention_head_size = dim // num_heads
        self.checkpoint = checkpoint
        self.tensor_parallel = tensor_parallel

        self.query_key_value = col_nn.Linear(dim,
                                             3 * dim,
                                             dtype=dtype,
                                             bias=bias,
                                             tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
                                             **_init_rules[init_method]['transformer'])
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim,
                                   dim,
                                   dtype=dtype,
                                   bias=True,
                                   tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
                                   **_init_rules[init_method]['transformer'])
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def _forward(self, x):
        qkv = self.query_key_value(x)
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = all_head_size // self.attention_head_size
        new_qkv_shape = qkv.shape[:-1] + \
            (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape)
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        x = torch.matmul(q, k.transpose(-1, -2))
        x = x / math.sqrt(self.attention_head_size)
        x = self.softmax(x)
        with seed(ParallelMode.TENSOR):
            x = self.attention_dropout(x)

        x = torch.matmul(x, v)
        x = x.transpose(1, 2)
        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
        x = x.reshape(new_context_layer_shape)

        x = self.dense(x)
        if self.tensor_parallel == '1d':
            x = self.dropout(x)
        else:
            with seed(ParallelMode.TENSOR):
                x = self.dropout(x)

        return x

    def _checkpoint_forward(self, x):
        return checkpoint(self._forward, x)

    def forward(self, x):
        if self.checkpoint:
            return self._checkpoint_forward(x)
        else:
            return self._forward(x)


@LAYERS.register_module
class ViTMLP(nn.Module):
    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch',
                 tensor_parallel: str = None):
        super().__init__()
        self.checkpoint = checkpoint
        self.tensor_parallel = tensor_parallel

        self.dense_1 = col_nn.Linear(dim,
                                     mlp_ratio * dim,
                                     dtype=dtype,
                                     bias=bias,
                                     tensor_parallel='1d_col' if tensor_parallel == '1d' else tensor_parallel,
                                     **_init_rules[init_method]['transformer'])
        self.activation = activation
        self.dense_2 = col_nn.Linear(mlp_ratio * dim,
                                     dim,
                                     dtype=dtype,
                                     bias=bias,
                                     tensor_parallel='1d_row' if tensor_parallel == '1d' else tensor_parallel,
                                     **_init_rules[init_method]['transformer'])
        self.dropout = nn.Dropout(dropout)

    def _forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        with seed(ParallelMode.TENSOR):
            x = self.dropout(x)
        x = self.dense_2(x)
        if self.tensor_parallel == '1d':
            x = self.dropout(x)
        else:
            with seed(ParallelMode.TENSOR):
                x = self.dropout(x)

        return x

    def _checkpoint_forward(self, x):
        return checkpoint(self._forward, x)

    def forward(self, x):
        if self.checkpoint:
            return self._checkpoint_forward(x)
        else:
            return self._forward(x)


@LAYERS.register_module
class ViTHead(nn.Module):
    def __init__(self,
                 dim: int,
                 num_classes: int,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch',
                 tensor_parallel: str = None):
        super().__init__()
        if representation_size:
            tensor_parallel_kwargs = {'tensor_parallel': '1d_col' if tensor_parallel == '1d' else tensor_parallel}
            if tensor_parallel == '1d':
                tensor_parallel_kwargs['gather_output'] = True
            self.representation = col_nn.Linear(dim,
                                                representation_size,
                                                bias=bias,
                                                dtype=dtype,
                                                **_init_rules[init_method]['head'],
                                                **tensor_parallel_kwargs)
        else:
            self.representation = None
            representation_size = dim

        self.linear = col_nn.Classifier(representation_size,
                                        num_classes,
                                        dtype=dtype,
                                        bias=bias,
                                        tensor_parallel=tensor_parallel,
                                        **_init_rules[init_method]['head'])

    def forward(self, x):
        x = x[:, 0]
        if self.representation is not None:
            x = self.representation(x)
        x = self.linear(x)
        return x


@LAYERS.register_module
class ViTBlock(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 drop_path: float = 0.,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch',
                 tensor_parallel: str = None):
        super().__init__()
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
        self.attn = ViTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     dtype=dtype,
                                     checkpoint=checkpoint,
                                     init_method=init_method,
                                     tensor_parallel=tensor_parallel)
        self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype, tensor_parallel=tensor_parallel)
        self.mlp = ViTMLP(dim=dim,
                          mlp_ratio=mlp_ratio,
                          activation=activation,
                          dropout=dropout,
                          dtype=dtype,
                          bias=bias,
                          checkpoint=checkpoint,
                          init_method=init_method,
                          tensor_parallel=tensor_parallel)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


@MODELS.register_module
class VisionTransformer(nn.Module):
    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 num_heads: int = 12,
                 dim: int = 768,
                 mlp_ratio: int = 4,
                 attention_dropout: float = 0.,
                 dropout: float = 0.1,
                 drop_path: float = 0.,
                 activation: Callable = nn.functional.gelu,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch',
                 tensor_parallel: str = None):
        super().__init__()

        embed = ViTEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embedding_dim=dim,
            dropout=dropout,
            dtype=dtype,
            init_method=init_method,
            tensor_parallel=tensor_parallel,
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            ViTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attention_dropout=attention_dropout,
                dropout=dropout,
                drop_path=dpr[i],
                activation=activation,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
                init_method=init_method,
                tensor_parallel=tensor_parallel,
            ) for i in range(depth)
        ]

        norm = col_nn.LayerNorm(
            normalized_shape=dim,
            eps=1e-6,
            dtype=dtype,
            tensor_parallel=tensor_parallel,
        )

        head = ViTHead(
            dim=dim,
            num_classes=num_classes,
            representation_size=representation_size,
            dtype=dtype,
            bias=bias,
            init_method=init_method,
            tensor_parallel=tensor_parallel,
        )

        self.layers = nn.Sequential(
            embed,
            *blocks,
            norm,
            head,
        )

    def forward(self, x):
        x = self.layers(x)
        return x


def _create_vit_model(**model_kwargs):
    model = VisionTransformer(**model_kwargs)
    return model


@MODELS.register_module
def vit_lite_depth7_patch4_32(**kwargs):
    model_kwargs = dict(img_size=32, patch_size=4, dim=256, depth=7, num_heads=4, mlp_ratio=2, num_classes=10, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch4_32(**kwargs):
    model_kwargs = dict(img_size=32, patch_size=4, dim=512, depth=6, num_heads=8, mlp_ratio=1, num_classes=10, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch32_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_small_patch32_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch32_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_base_patch32_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch16_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch16_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch32_224(**kwargs):
    model_kwargs = dict(img_size=224, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_large_patch32_384(**kwargs):
    model_kwargs = dict(img_size=384, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return _create_vit_model(**model_kwargs)
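
A short usage sketch (not part of the diff) for the new model zoo: the registered builders above accept the `tensor_parallel`, `init_method` and `checkpoint` keyword arguments and return a plain `nn.Module` whose `layers` attribute is what the pipeline builder consumes in the test further below. The config path is only an assumed example, and device placement is omitted for brevity.

```python
import colossalai
import torch
from model_zoo.vit import vit_lite_depth7_patch4_32

# set up the global parallel context first (config path is illustrative)
colossalai.launch_from_torch(config='configs/vit_1d.py')

# ViT-Lite-7/4 for CIFAR10; '1d' must match the tensor mode declared in the config
model = vit_lite_depth7_patch4_32(tensor_parallel='1d', init_method='jax')

x = torch.randn(8, 3, 32, 32)    # (batch, channels, height, width)
logits = model(x)                # embed -> blocks -> norm -> head, shape (8, 10)
```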
74
tests/test_comm/test_comm.py
Normal file
@ -0,0 +1,74 @@
import time
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import all_gather, all_reduce, reduce_scatter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import get_current_device

CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1)))

SIZE = 8


def check_all_gather():
    tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
    tensor = tensor.to(get_current_device())
    print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    tensor, op = all_gather(tensor, 0, ParallelMode.GLOBAL, async_op=True)
    print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    op.wait()
    print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    torch.cuda.synchronize()


def check_reduce_scatter():
    tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
    tensor = tensor.to(get_current_device())
    print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    tensor, op = reduce_scatter(tensor, 0, ParallelMode.GLOBAL, async_op=True)
    print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    op.wait()
    print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    torch.cuda.synchronize()


def check_all_reduce():
    tensor = torch.tensor([dist.get_rank() * SIZE + j for j in range(SIZE)])
    tensor = tensor.to(get_current_device())
    print('Before: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    tensor, op = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
    print('After: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    op.wait()
    print('Complete: Rank {0} - {1}'.format(dist.get_rank(), tensor))
    torch.cuda.synchronize()


def check_layer(rank, world_size):
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30010, backend='nccl')

    assert dist.get_rank() == gpc.get_global_rank()
    print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))

    check_all_gather()
    check_reduce_scatter()
    check_all_reduce()

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_comm():
    world_size = 4
    run_func = partial(check_layer, world_size=world_size)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_comm()
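
The new test above exercises the asynchronous form of the communication helpers: with `async_op=True` each collective returns the output tensor together with a work handle that must be waited on before the result is read. A condensed sketch of that pattern (same API as the test, not additional repository code):

```python
import torch
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device


def sum_across_ranks(value: float) -> torch.Tensor:
    # every rank contributes one scalar; all ranks receive the global sum
    tensor = torch.tensor([value], device=get_current_device())
    tensor, handle = all_reduce(tensor, ParallelMode.GLOBAL, async_op=True)
    handle.wait()    # block until the collective has finished
    return tensor
```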
@ -1,141 +0,0 @@
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from colossalai.amp.amp_type import AMP_TYPE
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.logging import get_dist_logger
|
||||
import colossalai
|
||||
import torch
|
||||
import os
|
||||
from colossalai.builder import build_pipeline_model_from_cfg
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import get_dataloader, MultiTimer
|
||||
from colossalai.nn.loss import CrossEntropyLoss2D
|
||||
from colossalai.trainer.metric import Accuracy2D
|
||||
from colossalai.trainer import metric, hooks, Trainer
|
||||
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
|
||||
from colossalai.engine.schedule import PipelineSchedule
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from colossalai.nn import LinearWarmupLR
|
||||
from tqdm import tqdm
|
||||
import vit_t_2d
|
||||
|
||||
BATCH_SIZE = 16
|
||||
NUM_EPOCHS = 60
|
||||
WARMUP_EPOCHS = 5
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
pipeline=2,
|
||||
tensor=dict(size=4, mode='2d')
|
||||
),
|
||||
fp16=dict(
|
||||
mode=AMP_TYPE.TORCH
|
||||
),
|
||||
gradient_accumulation=2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
|
||||
def test_hybrid_parallel():
|
||||
parser = colossalai.get_default_parser()
|
||||
args = parser.parse_args()
|
||||
colossalai.launch_from_slurm(config=CONFIG,
|
||||
host=args.host,
|
||||
port=29500)
|
||||
|
||||
logger = get_dist_logger()
|
||||
# if gpc.get_global_rank() == 0:
|
||||
# logger.log_to_file('./logs/cifar10_2d_vit',
|
||||
# suffix='cifar10_2d_vit_ddp1_torch_amp_grad_accum_2_clip_grad_1', mode='w')
|
||||
|
||||
# build vit-t-32
|
||||
model = build_pipeline_model_from_cfg(vit_t_2d.model_cfg, num_chunks=1)
|
||||
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.RandomCrop(size=32, padding=4),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
test_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
train=False,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
|
||||
0.2023, 0.1994, 0.2010]),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
add_sampler=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
test_dataloader = get_dataloader(dataset=test_dataset,
|
||||
add_sampler=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# build criterion
|
||||
criterion = CrossEntropyLoss2D()
|
||||
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
|
||||
|
||||
# lr_scheduler
|
||||
steps_per_epoch = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(train_dataloader, accumulate_size=2)
|
||||
total_steps = steps_per_epoch * NUM_EPOCHS
|
||||
warmup_steps = steps_per_epoch * WARMUP_EPOCHS
|
||||
lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)
|
||||
|
||||
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
|
||||
model, optimizer, criterion, train_dataloader, test_dataloader, lr_scheduler)
|
||||
|
||||
timer = MultiTimer()
|
||||
|
||||
schedule = PipelineSchedule(num_microbatches=4)
|
||||
|
||||
trainer = Trainer(
|
||||
engine=engine,
|
||||
timer=timer,
|
||||
logger=logger,
|
||||
schedule=schedule
|
||||
)
|
||||
|
||||
hook_list = [
|
||||
hooks.LossHook(),
|
||||
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
|
||||
hooks.Accuracy2DHook(),
|
||||
hooks.LogMetricByEpochHook(logger),
|
||||
]
|
||||
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
epochs=NUM_EPOCHS,
|
||||
test_dataloader=test_dataloader,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_hybrid_parallel()
|
@ -1,3 +0,0 @@
#!/usr/bin/env sh

python run_cifar10_vit2d_with_pipeline.py --host $HOST
@ -0,0 +1,103 @@
import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.amp.amp_type import AMP_TYPE
from colossalai.builder import build_pipeline_model
from colossalai.engine.schedule import PipelineSchedule
from colossalai.logging import get_dist_logger
from colossalai.nn import Accuracy, LinearWarmupLR
from colossalai.nn.loss import CrossEntropyLoss
from colossalai.trainer import Trainer, hooks
from colossalai.utils import MultiTimer, get_dataloader
from colossalai.utils.gradient_accumulation import GradAccumLrSchedulerByStep
from model_zoo.vit import vit_tiny_patch4_32
from torchvision import transforms
from torchvision.datasets import CIFAR10

BATCH_SIZE = 16
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
CONFIG = dict(parallel=dict(pipeline=2, tensor=dict(size=2, mode='1d')),
              fp16=dict(mode=AMP_TYPE.TORCH),
              gradient_accumulation=2)


def run_trainer(rank, world_size):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=30000, backend='nccl')

    logger = get_dist_logger()

    model = vit_tiny_patch4_32(tensor_parallel='1d')
    pipe_model = build_pipeline_model(model.layers, num_chunks=1)

    # build dataloaders
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = CIFAR10(root=Path(os.environ['DATA']), train=True, download=True, transform=transform_train)
    test_dataset = CIFAR10(root=Path(os.environ['DATA']), train=False, transform=transform_test)
    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True)

    # build criterion
    criterion = CrossEntropyLoss(tensor_parallel='1d')

    # optimizer
    optimizer = torch.optim.Adam(pipe_model.parameters(), lr=0.001, weight_decay=0)

    # lr_scheduler
    steps_per_epoch = GradAccumLrSchedulerByStep.compute_effective_steps_per_epoch(train_dataloader, accumulate_size=2)
    total_steps = steps_per_epoch * NUM_EPOCHS
    warmup_steps = steps_per_epoch * WARMUP_EPOCHS
    lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)

    engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(pipe_model, optimizer, criterion,
                                                                                    train_dataloader, test_dataloader,
                                                                                    lr_scheduler)

    timer = MultiTimer()

    schedule = PipelineSchedule(num_microbatches=4)

    trainer = Trainer(engine=engine, timer=timer, logger=logger, schedule=schedule)

    hook_list = [
        hooks.LossHook(),
        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
        hooks.AccuracyHook(accuracy_func=Accuracy(tensor_parallel='1d')),
        hooks.LogMetricByEpochHook(logger),
    ]

    trainer.fit(train_dataloader=train_dataloader,
                epochs=NUM_EPOCHS,
                max_steps=5,
                test_dataloader=test_dataloader,
                test_interval=1,
                hooks=hook_list,
                display_progress=True)


@pytest.mark.dist
# @pytest.mark.skip("This test requires more than 8 GPUs, you should invoke this test script using test.sh provided manually")
def test_hybrid_parallel():
    world_size = 8
    run_func = partial(run_trainer, world_size=world_size)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_hybrid_parallel()
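
One point worth noting about the test above (not stated in the diff): because `gradient_accumulation=2` folds two batches into one optimizer update, the scheduler is driven by effective steps rather than raw batches. A rough illustration of that bookkeeping, under the assumption that the effective count is simply the dataloader length divided by the accumulation size:

```python
# hypothetical numbers; only the arithmetic mirrors the test above
NUM_EPOCHS = 60
WARMUP_EPOCHS = 5
accumulate_size = 2

num_batches = 400                                   # e.g. len(train_dataloader)
steps_per_epoch = num_batches // accumulate_size    # optimizer updates per epoch -> 200
total_steps = steps_per_epoch * NUM_EPOCHS          # 12000 scheduler steps in total
warmup_steps = steps_per_epoch * WARMUP_EPOCHS      # the first 1000 are warmup
```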
@ -1,74 +0,0 @@
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
repo_path = str(Path(__file__).absolute().parents[2])
|
||||
sys.path.append(repo_path)
|
||||
|
||||
try:
|
||||
import model_zoo.vit.vision_transformer_from_config
|
||||
except ImportError:
|
||||
raise ImportError("model_zoo is not found, please check your path")
|
||||
|
||||
IMG_SIZE = 32
|
||||
PATCH_SIZE = 4
|
||||
DIM = 512
|
||||
NUM_ATTENTION_HEADS = 8
|
||||
NUM_CLASSES = 10
|
||||
DEPTH = 6
|
||||
|
||||
model_cfg = dict(
|
||||
type='VisionTransformerFromConfig',
|
||||
tensor_splitting_cfg=dict(
|
||||
type='ViTInputSplitter2D',
|
||||
),
|
||||
embedding_cfg=dict(
|
||||
type='ViTPatchEmbedding2D',
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=PATCH_SIZE,
|
||||
embed_dim=DIM,
|
||||
),
|
||||
token_fusion_cfg=dict(
|
||||
type='ViTTokenFuser2D',
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=PATCH_SIZE,
|
||||
embed_dim=DIM,
|
||||
drop_rate=0.1
|
||||
),
|
||||
norm_cfg=dict(
|
||||
type='LayerNorm2D',
|
||||
normalized_shape=DIM,
|
||||
eps=1e-6,
|
||||
),
|
||||
block_cfg=dict(
|
||||
type='ViTBlock',
|
||||
attention_cfg=dict(
|
||||
type='ViTSelfAttention2D',
|
||||
hidden_size=DIM,
|
||||
num_attention_heads=NUM_ATTENTION_HEADS,
|
||||
attention_dropout_prob=0.,
|
||||
hidden_dropout_prob=0.1,
|
||||
),
|
||||
droppath_cfg=dict(
|
||||
type='VanillaViTDropPath',
|
||||
),
|
||||
mlp_cfg=dict(
|
||||
type='ViTMLP2D',
|
||||
in_features=DIM,
|
||||
dropout_prob=0.1,
|
||||
mlp_ratio=1
|
||||
),
|
||||
norm_cfg=dict(
|
||||
type='LayerNorm2D',
|
||||
normalized_shape=DIM,
|
||||
eps=1e-6,
|
||||
),
|
||||
),
|
||||
head_cfg=dict(
|
||||
type='ViTHead2D',
|
||||
hidden_size=DIM,
|
||||
num_classes=NUM_CLASSES,
|
||||
),
|
||||
embed_dim=DIM,
|
||||
depth=DEPTH,
|
||||
drop_path_rate=0.,
|
||||
)
|
@ -1,40 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
BATCH_SIZE = 128
|
||||
IMG_SIZE = 224
|
||||
DIM = 768
|
||||
NUM_CLASSES = 10
|
||||
NUM_ATTN_HEADS = 12
|
||||
|
||||
# resnet 18
|
||||
model = dict(type='VanillaResNet',
|
||||
block_type='ResNetBasicBlock',
|
||||
layers=[2, 2, 2, 2],
|
||||
num_cls=10)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
|
||||
train_data = dict(dataset=dict(type='CIFAR10Dataset',
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform_pipeline=[
|
||||
dict(type='Resize',
|
||||
size=(IMG_SIZE, IMG_SIZE)),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize',
|
||||
mean=(0.5, 0.5, 0.5),
|
||||
std=(0.5, 0.5, 0.5))
|
||||
]),
|
||||
dataloader=dict(batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
num_workers=4,
|
||||
drop_last=True))
|
||||
|
||||
optimizer = dict(type='Adam', lr=0.001)
|
||||
|
||||
loss = dict(type='CrossEntropyLoss')
|
||||
|
@ -1,16 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
BATCH_SIZE = 128
|
||||
IMG_SIZE = 224
|
||||
DIM = 768
|
||||
NUM_CLASSES = 10
|
||||
NUM_ATTN_HEADS = 12
|
||||
|
||||
|
||||
parallel = dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
fp16 = dict(mode=AMP_TYPE.APEX)
|
@ -1,42 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from colossalai.engine import AMP_TYPE
|
||||
|
||||
BATCH_SIZE = 128
|
||||
IMG_SIZE = 224
|
||||
DIM = 768
|
||||
NUM_CLASSES = 10
|
||||
NUM_ATTN_HEADS = 12
|
||||
|
||||
# resnet 18
|
||||
model = dict(type='VanillaResNet',
|
||||
block_type='ResNetBasicBlock',
|
||||
layers=[2, 2, 2, 2],
|
||||
num_cls=10)
|
||||
|
||||
parallel = dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
|
||||
train_data = dict(dataset=dict(type='CIFAR10Dataset',
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform_pipeline=[
|
||||
dict(type='Resize',
|
||||
size=(IMG_SIZE, IMG_SIZE)),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize',
|
||||
mean=(0.5, 0.5, 0.5),
|
||||
std=(0.5, 0.5, 0.5))
|
||||
]),
|
||||
dataloader=dict(batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
num_workers=4,
|
||||
drop_last=True))
|
||||
|
||||
optimizer = dict(type='Adam', lr=0.001)
|
||||
|
||||
loss = dict(type='CrossEntropyLoss')
|
||||
fp16 = dict(mode=AMP_TYPE.TORCH)
|
@ -1,46 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
BATCH_SIZE = 128
|
||||
IMG_SIZE = 224
|
||||
DIM = 768
|
||||
NUM_CLASSES = 10
|
||||
NUM_ATTN_HEADS = 12
|
||||
|
||||
# resnet 18
|
||||
model = dict(type='VanillaResNet',
|
||||
block_type='ResNetBasicBlock',
|
||||
layers=[2, 2, 2, 2],
|
||||
num_cls=10)
|
||||
|
||||
train_data = dict(dataset=dict(type='CIFAR10Dataset',
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform_pipeline=[
|
||||
dict(type='Resize',
|
||||
size=(IMG_SIZE, IMG_SIZE)),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize',
|
||||
mean=(0.5, 0.5, 0.5),
|
||||
std=(0.5, 0.5, 0.5))
|
||||
]),
|
||||
dataloader=dict(batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
num_workers=4,
|
||||
drop_last=True))
|
||||
|
||||
optimizer = dict(type='Adam', lr=0.001)
|
||||
|
||||
loss = dict(type='CrossEntropyLoss')
|
||||
|
||||
parallel = dict(
|
||||
pipeline=dict(size=4),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
|
||||
engine = dict(
|
||||
schedule=dict(
|
||||
num_microbatches=4
|
||||
)
|
||||
)
|
||||
num_epochs = 10
|
@ -4,7 +4,7 @@ from torch.nn import Parameter
|
||||
import time
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import Linear1D_Col, Linear1D_Row, TransformerMLP1D, TransformerSelfAttention1D, ViTMLP1D, ViTSelfAttention1D, ViTPatchEmbedding1D, ViTHead1D, ViTTokenFuser1D
|
||||
from colossalai.nn import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
|
||||
|
||||
@ -17,7 +17,7 @@ def check_linear_col():
|
||||
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
|
||||
layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE, gather_output=True)
|
||||
layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
@ -50,18 +50,20 @@ def check_linear_col():
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C = C_master.clone()
|
||||
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('linear_col gather_output forward: pass')
|
||||
print_rank_0('linear_col forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
dist.broadcast(grad_master, src=0)
|
||||
grad = grad_master.detach()
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
@ -73,7 +75,7 @@ def check_linear_col():
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('linear_col gather_output backward: pass')
|
||||
print_rank_0('linear_col backward: pass')
|
||||
|
||||
|
||||
def check_linear_row():
|
||||
@ -84,12 +86,13 @@ def check_linear_row():
|
||||
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
|
||||
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, parallel_input=False)
|
||||
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
dist.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
|
||||
@ -119,16 +122,18 @@ def check_linear_row():
|
||||
C = C_master.clone()
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('linear_row no parallel_input forward: pass')
|
||||
print_rank_0('linear_row forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
dist.broadcast(grad_master, src=0)
|
||||
grad = grad_master.detach()
|
||||
grad = grad_master.clone()
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
@@ -138,276 +143,4 @@ def check_linear_row():
     B_grad = B_master.grad
     check_equal(B_grad, layer.bias.grad)

-    print_rank_0('linear_row no parallel_input backward: pass')
-
-
-class Testvithead(torch.nn.Module):
-    def __init__(self, in_features, out_features, bias=True):
-        super().__init__()
-        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
-
-    def forward(self, x):
-        x = x[:, 0]
-        x = self.linear(x)
-        return x
-
-
-def check_head():
-    device = get_current_device()
-    dtype = torch.float32
-    INPUT_SIZE = HIDDEN_SIZE
-
-    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-
-    head = ViTHead1D(INPUT_SIZE, NUM_CLASSES, dtype=dtype)
-    torch.nn.init.zeros_(head.linear.bias)
-    torch.nn.init.ones_(head.linear.weight)
-    head = head.to(device)
-
-    layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
-    torch.nn.init.zeros_(layer.linear.bias)
-    torch.nn.init.ones_(layer.linear.weight)
-    layer = layer.to(device)
-
-    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
-    torch.distributed.broadcast(A_master, src=0)
-    A = A_master.clone()
-    A.requires_grad = True
-
-    fwd_start = time.time()
-    out = head(A)
-    fwd_end = time.time()
-    print_rank_0(
-        'head forward: pass | {0} --> {1} | {2:.3f} s'.format(
-            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
-    A_master = A_master.clone()
-    A_master.requires_grad = True
-    C_master = layer(A_master)
-    # C = torch.chunk(C_master, DEPTH, dim=0)[i]
-    print_rank_0('Rank {} head forward: {}'.format(i, check_equal(out, C_master)))
-
-    grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape,
-                              dtype=dtype,
-                              device=get_current_device())
-    torch.distributed.broadcast(grad_master, src=0)
-    # grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
-
-    # bwd_start = time.time()
-    out.backward(grad_master)
-    # bwd_end = time.time()
-    # print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
-    #              logger)
-
-    C_master.backward(grad_master)
-    A_grad = A_master.grad
-    # if j == 0:
-    print_rank_0('Rank {} head backward (input_grad): {}'.format(
-        i, check_equal(A_grad, A.grad)))
-
-
-class Testvitembed(torch.nn.Module):
-    def __init__(self, img_size: int, patch_size: int, in_chans: int,
-                 embed_size: int, drop_prob: float) -> None:
-        super().__init__()
-        self.proj = torch.nn.Conv2d(in_chans,
-                                    embed_size,
-                                    kernel_size=patch_size,
-                                    stride=patch_size)
-        num_patches = (img_size // patch_size)**2
-        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size))
-        self.pos_embed = torch.nn.Parameter(
-            torch.zeros(1, num_patches + 1, embed_size))
-        self.pos_drop = torch.nn.Dropout(drop_prob)
-
-    def forward(self, x):
-        x = self.proj(x)
-        x = x.flatten(2).transpose(1, 2)
-        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_token, x), dim=1)
-        x = self.pos_drop(x + self.pos_embed)
-        return x
-
-
-def check_embed():
-    device = get_current_device()
-    dtype = torch.float32
-    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-
-    layer = ViTPatchEmbedding1D(IMG_SIZE, 4, HIDDEN_SIZE)
-    layer2 = ViTTokenFuser1D(IMG_SIZE, 4, HIDDEN_SIZE)
-    torch.nn.init.zeros_(layer.proj.bias)
-    torch.nn.init.ones_(layer.proj.weight)
-    torch.nn.init.ones_(layer2.cls_token)
-    torch.nn.init.ones_(layer2.pos_embed)
-    layer = layer.to(device)
-    layer2 = layer2.to(device)
-
-    layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.)
-    torch.nn.init.zeros_(layer_master.proj.bias)
-    torch.nn.init.ones_(layer_master.proj.weight)
-    torch.nn.init.ones_(layer_master.cls_token)
-    torch.nn.init.ones_(layer_master.pos_embed)
-    layer_master = layer_master.to(device)
-
-    A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
-    torch.distributed.broadcast(A_master, src=0)
-    A = A_master.clone()
-    A.requires_grad = True
-
-    fwd_start = time.time()
-    out = layer2(layer(A))
-    fwd_end = time.time()
-    print_rank_0(
-        'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(
-            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start))
-    # out_cls = out[:, 0]
-    # out_tensor = out[:, 1:]
-
-    A_master = A_master.clone()
-    A_master.requires_grad = True
-    C_master = layer_master(A_master)
-    # if j == 0:
-    #     C_cls = C_master[:, 0]
-    #     C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i]
-    #     C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k]
-    #     logger.info('Rank {} embed forward (cls): {}'.format(
-    #         rank, check_equal(out_cls, C_cls)))
-    # C = C_master[:, 1:]
-    print_rank_0('Rank {} embed forward: {}'.format(i, check_equal(out, C_master)))
-
-    grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape,
-                              dtype=dtype,
-                              device=get_current_device())
-    torch.distributed.broadcast(grad_master, src=0)
-    # cls_grad = grad_master[:, 0]
-    # cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i]
-    # cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k]
-    # grad = grad_master[:, 1:]
-    # grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1)
-    bwd_start = time.time()
-    out.backward(grad_master)
-    bwd_end = time.time()
-    print_rank_0(
-        'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start))
-
-    C_master.backward(grad_master)
-
-    A_grad = A_master.grad
-    print_rank_0('Rank {} embed backward (input_grad): {}'.format(i, check_equal(A_grad, A.grad)))
-
-    print_rank_0('Rank {} embed backward (cls_grad): {}'.format(
-        i, check_equal(layer_master.cls_token.grad, layer2.cls_token.grad)))
-
-    print_rank_0('Rank {} embed backward (pos_embed_grad): {}'.format(
-        i, check_equal(layer_master.pos_embed.grad, layer2.pos_embed.grad)))
-
-    print_rank_0('Rank {} embed backward (proj_weight_grad): {}'.format(
-        i, check_equal(layer_master.proj.weight.grad, layer.proj.weight.grad)))
-
-    print_rank_0('Rank {} embed backward (proj_bias_grad): {}'.format(
-        i, check_equal(layer_master.proj.bias.grad, layer.proj.bias.grad)))
-
-    return fwd_end - fwd_start, bwd_end - bwd_start
-
-
-def check_attention():
-    device = get_current_device()
-    dtype = torch.float32
-    INPUT_SIZE = HIDDEN_SIZE
-    NUM_ATTENTION_HEADS = 2
-
-    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-
-    layer = ViTSelfAttention1D(
-        HIDDEN_SIZE,
-        NUM_ATTENTION_HEADS,
-        0.5,
-        0.5
-    ).to(device=device)
-
-    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
-    torch.distributed.broadcast(A_master, src=0)
-    A = A_master.clone()
-    A.requires_grad = True
-
-    mask_shape = (BATCH_SIZE, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
-    attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
-
-    out = layer(A)
-    assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    print_rank_0('self attention forward: pass')
-
-    grad_shape = out.shape
-    grad = torch.randn(grad_shape, dtype=dtype, device=device)
-
-    out.backward(grad)
-    assert A.grad.shape == A.shape
-    print_rank_0('self attention backward: pass')
-
-
-def check_mlp():
-    device = get_current_device()
-    dtype = torch.float32
-    INPUT_SIZE = HIDDEN_SIZE
-
-    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-
-    layer = ViTMLP1D(
-        HIDDEN_SIZE,
-        4.0
-    ).to(device=device)
-
-    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
-    torch.distributed.broadcast(A_master, src=0)
-    A = A_master.clone()
-    A.requires_grad = True
-
-    out = layer(A)
-    assert out.shape == (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    print_rank_0('mlp forward: pass')
-
-    grad_shape = out.shape
-    grad = torch.randn(grad_shape, dtype=dtype, device=device)
-
-    out.backward(grad)
-    assert A.grad.shape == A.shape
-    print_rank_0('mlp backward: pass')
-
-
-def check_patch_embedding():
-    device = get_current_device()
-    dtype = torch.float32
-    INPUT_SIZE = 4
-    PATCH_SIZE = 2
-
-    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-
-    layer = ViTPatchEmbedding1D(
-        INPUT_SIZE,
-        PATCH_SIZE,
-        HIDDEN_SIZE,
-    ).to(device=device)
-
-    A_shape = (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
-    torch.distributed.broadcast(A_master, src=0)
-    A = A_master.clone()
-    A.requires_grad = True
-
-    out = layer(A)
-    print('output size: ', out.size())
-    assert out.shape == (BATCH_SIZE, 4, HIDDEN_SIZE)
-    print_rank_0('patch embedding forward: pass')
-
-    grad_shape = out.shape
-    grad = torch.randn(grad_shape, dtype=dtype, device=device)
-
-    out.backward(grad)
-    assert A.grad.shape == A.shape
-    print_rank_0('patch embedding backward: pass')
+    print_rank_0('linear_row backward: pass')
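The hunk above removes the ViT-specific checks (`check_head`, `check_embed`, `check_attention`, `check_mlp`, `check_patch_embedding`) from this file. Their common harness is: initialize the layer under test and a plain PyTorch reference identically, run both on the same broadcast input, then compare outputs and gradients. The single-process sketch below distills that pattern; an ordinary `torch.nn.Linear` stands in for the parallel `ViTHead1D`, so it illustrates only the comparison harness, not the parallel layer itself.

```
import torch


class RefHead(torch.nn.Module):
    """Reference classifier head: take the cls token, then a dense Linear."""

    def __init__(self, in_features, num_classes):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.linear(x[:, 0])


head = RefHead(8, 8)      # stand-in for the parallel layer under test
ref = RefHead(8, 8)       # the Testvithead-style master module

# Deterministic, identical initialization, mirroring the zeros_/ones_ calls
# the removed check_head used so that both modules are numerically comparable.
for m in (head, ref):
    torch.nn.init.zeros_(m.linear.bias)
    torch.nn.init.ones_(m.linear.weight)

x1 = torch.randn(4, 8, 8, requires_grad=True)   # (batch, seq, hidden)
x2 = x1.detach().clone().requires_grad_()

out, expected = head(x1), ref(x2)
assert torch.allclose(out, expected)

grad = torch.randn_like(out)
out.backward(grad)
expected.backward(grad)
assert torch.allclose(x1.grad, x2.grad)
```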
@@ -3,12 +3,12 @@

 import torch

-DEPTH = 2
+DEPTH = 4
 BATCH_SIZE = 8
 SEQ_LENGTH = 8
 IMG_SIZE = 16
 HIDDEN_SIZE = 8
-NUM_CLASSES = 10
+NUM_CLASSES = 8

 def check_equal(A, B):
-    assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
+    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
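Finally, the shared test constants change: `DEPTH` (the 1D tensor-parallel size) goes from 2 to 4, `NUM_CLASSES` from 10 to 8, and `check_equal` relaxes its tolerance to `rtol=1e-3, atol=1e-1`. Since `torch.allclose` already returns a bool, the `== True` comparison is redundant; a slightly tidier form of the helper is sketched below, where the failure message is an addition for readability and not code from the repository.

```
import torch


def check_equal(A, B):
    # torch.allclose already returns a Python bool, so comparing with True is unnecessary.
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1), \
        "tensors differ: max abs diff = {:.3e}".format((A - B).abs().max().item())
```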
Some files were not shown because too many files have changed in this diff.