optimize engine and trainer test (#448)

Frank Lee
2022-03-17 15:44:17 +08:00
committed by GitHub
parent 237d08e7ee
commit bb2790cf0b
6 changed files with 111 additions and 197 deletions
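The first file below is the pipeline point-to-point communication test; the second is the pipeline schedule test built around a ResNet config and CIFAR10. The commit mainly shrinks the tensor and batch sizes the tests use and reflows multi-line calls onto single lines so the tests run faster. The check_equal helper referenced in the hunks is untouched by the commit, so its body is elided from the diff; a minimal sketch of what such a helper typically does, assuming an elementwise comparison (a hypothetical reconstruction, not code from this commit):

def check_equal(A, B):
    # assumed body: compare the received tensor against the expected one
    # elementwise, allowing a small numerical tolerance
    return torch.allclose(A, B, rtol=1e-5, atol=1e-3)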

View File

@@ -7,10 +7,8 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
-from colossalai.communication import (recv_backward, recv_forward,
-                                      recv_tensor_meta, send_backward,
-                                      send_backward_recv_forward, send_forward,
-                                      send_forward_recv_backward,
+from colossalai.communication import (recv_backward, recv_forward, recv_tensor_meta, send_backward,
+                                      send_backward_recv_forward, send_forward, send_forward_recv_backward,
                                       send_tensor_meta)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -18,17 +16,11 @@ from colossalai.initialize import launch
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_current_device
-BATCH_SIZE = 16
-SEQ_LENGTH = 64
-HIDDEN_SIZE = 128
+BATCH_SIZE = 4
+SEQ_LENGTH = 2
+HIDDEN_SIZE = 16
-CONFIG = dict(
-    parallel=dict(
-        pipeline=dict(size=4),
-        tensor=dict(size=1, mode=None)
-    ),
-    seed=1024
-)
+CONFIG = dict(parallel=dict(pipeline=dict(size=4), tensor=dict(size=1, mode=None)), seed=1024)
def check_equal(A, B):
@@ -41,8 +33,7 @@ def check_forward(output_tensor, rank, logger):
tensor = output_tensor.clone()
else:
tensor = recv_forward(output_tensor.shape)
-logger.info('Rank {} received forward. Correct tensor: {}'.format(
-    rank, check_equal(tensor, output_tensor)))
+logger.info('Rank {} received forward. Correct tensor: {}'.format(rank, check_equal(tensor, output_tensor)))
if not gpc.is_last_rank(ParallelMode.PIPELINE):
send_forward(tensor)
logger.info('Rank {} sent forward.'.format(rank))
@@ -54,8 +45,7 @@ def check_backward(output_grad, rank, logger):
grad = output_grad.clone()
else:
grad = recv_backward(output_grad.shape)
-logger.info('Rank {} received backward. Correct grad: {}'.format(
-    rank, check_equal(grad, output_grad)))
+logger.info('Rank {} received backward. Correct grad: {}'.format(rank, check_equal(grad, output_grad)))
if not gpc.is_first_rank(ParallelMode.PIPELINE):
send_backward(grad)
logger.info('Rank {} sent backward.'.format(rank))
@@ -65,17 +55,15 @@ def check_forward_backward(output_tensor, output_grad, rank, logger):
dist.barrier()
if not gpc.is_first_rank(ParallelMode.PIPELINE):
tensor = send_backward_recv_forward(output_grad, output_tensor.shape)
-logger.info(
-    'Rank {} sent backward received forward. Correct tensor: {}'.
-    format(rank, check_equal(tensor, output_tensor)))
+logger.info('Rank {} sent backward received forward. Correct tensor: {}'.format(
+    rank, check_equal(tensor, output_tensor)))
if not gpc.is_last_rank(ParallelMode.PIPELINE):
grad = send_forward_recv_backward(output_tensor, output_grad.shape)
-logger.info(
-    'Rank {} sent forward received backward. Correct grad: {}'.format(
-        rank, check_equal(grad, output_grad)))
+logger.info('Rank {} sent forward received backward. Correct grad: {}'.format(
+    rank, check_equal(grad, output_grad)))
def check_comm(size, rank, prev_rank, next_rank, logger):
dtype = torch.float32
device = get_current_device()
tensor_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
@@ -90,21 +78,12 @@ def check_comm(size, rank, prev_rank, next_rank, logger):
def run_check(rank, world_size, port):
-launch(
-    config=CONFIG,
-    rank=rank,
-    world_size=world_size,
-    host='localhost',
-    port=port,
-    backend='nccl'
-)
+launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
rank = gpc.get_global_rank()
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-logger.info(
-    'Rank {0}: prev rank {1}, next rank {2}'.format(
-        rank, prev_rank, next_rank))
+logger.info('Rank {0}: prev rank {1}, next rank {2}'.format(rank, prev_rank, next_rank))
logger.info('Distributed environment is initialized.')
check_comm(world_size, rank, prev_rank, next_rank, logger)
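The pytest entry point that launches run_check across the four pipeline ranks sits outside the hunks shown above. For context, a minimal sketch of the usual driver pattern, inferred from the imports of pytest, torch.multiprocessing and free_port (the test-function name and the dist marker are placeholders, not code from this commit):

import pytest
import torch.multiprocessing as mp
from colossalai.utils import free_port

@pytest.mark.dist
def test_p2p_communication():
    world_size = 4    # matches pipeline=dict(size=4) in CONFIG
    # mp.spawn prepends the worker rank, so each process ends up calling
    # run_check(rank, world_size, port)
    mp.spawn(run_check, args=(world_size, free_port()), nprocs=world_size)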

View File

@@ -17,48 +17,34 @@ from colossalai.utils import free_port, get_dataloader, print_rank_0
from torchvision import transforms
from torchvision.datasets import CIFAR10
import model
-BATCH_SIZE = 32
-NUM_MICRO = 8
+BATCH_SIZE = 4
+NUM_MICRO = 2
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, './resnet_config.py')
def run_schedule(rank, world_size, port):
-launch(config=CONFIG_PATH,
-       rank=rank,
-       world_size=world_size,
-       host='localhost',
-       port=port,
-       backend='nccl')
+launch(config=CONFIG_PATH, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# build model
model = build_pipeline_model_from_cfg(gpc.config.model, 1)
print_rank_0('model is created')
-train_dataset = CIFAR10(
-    root=Path(os.environ['DATA']),
-    download=True,
-    transform=transforms.Compose(
-        [
-            transforms.RandomCrop(size=32, padding=4),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
-                0.2023, 0.1994, 0.2010]),
-        ]
-    )
-)
+train_dataset = CIFAR10(root=Path(os.environ['DATA']),
+                        download=True,
+                        transform=transforms.Compose([
+                            transforms.ToTensor(),
+                            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
+                        ]))
-train_dataloader = get_dataloader(dataset=train_dataset,
-                                  shuffle=True,
-                                  add_sampler=True,
-                                  batch_size=BATCH_SIZE,
-                                  pin_memory=True,
-                                  )
+train_dataloader = get_dataloader(
+    dataset=train_dataset,
+    shuffle=True,
+    add_sampler=True,
+    batch_size=BATCH_SIZE,
+    pin_memory=True,
+)
# build criterion
criterion = torch.nn.CrossEntropyLoss()
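BATCH_SIZE and NUM_MICRO are lowered together because the pipeline schedule splits every batch into NUM_MICRO micro-batches, so the batch size must stay divisible by the micro-batch count; the schedule itself is presumably constructed further down in the test, outside the lines shown. A minimal illustration of the relationship:

# 4 samples per step are split into 2 micro-batches of 2 samples each;
# BATCH_SIZE has to remain divisible by NUM_MICRO for the split to work
assert BATCH_SIZE % NUM_MICRO == 0
MICRO_BATCH_SIZE = BATCH_SIZE // NUM_MICRO    # 4 // 2 == 2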