mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640) * [legacy] refactor logger and clean up legacy codes (#4654) * [legacy] make logger independent to gpc * [legacy] make optim independent to registry * [legacy] move test engine to legacy * [legacy] move nn to legacy (#4656) * [legacy] move nn to legacy * [checkpointio] fix save hf config * [test] remove useledd rpc pp test * [legacy] fix nn init * [example] skip tutorial hybriad parallel example * [devops] test doc check * [devops] test doc check
This commit is contained in:
62
tests/test_legacy/test_engine/test_engine.py
Normal file
62
tests/test_legacy/test_engine/test_engine.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
|
||||
fp16=dict(mode=None),
|
||||
clip_grad_norm=1.0)
|
||||
|
||||
|
||||
@parameterize('model_name', ['repeated_computed_layers', 'resnet18', 'repeated_computed_layers'])
|
||||
@parameterize('amp_mode', [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
|
||||
def run_train(model_name, amp_mode):
|
||||
# FIXME: test bert
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
gpc.config.fp16['mode'] = amp_mode
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
model = model_builder(checkpoint=False)
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer_class(model.parameters(), lr=1e-3),
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
|
||||
try:
|
||||
engine.train()
|
||||
for data, label in train_dataloader:
|
||||
engine.zero_grad()
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
if criterion:
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
else:
|
||||
loss = engine(data, label)
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
break
|
||||
except IndexError:
|
||||
# if using apex amp, NetWithRepeatedlyComputedLayers will raise an index out of range issue
|
||||
# the following check fails in apex
|
||||
# if cached_x.grad_fn.next_functions[1][0].variable is not x:
|
||||
pass
|
||||
|
||||
|
||||
def run_engine(rank, world_size, port):
|
||||
# init dist env
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_train()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_engine():
|
||||
spawn(run_engine, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_engine()
|
95
tests/test_legacy/test_engine/test_gradient_accumluation.py
Normal file
95
tests/test_legacy/test_engine/test_gradient_accumluation.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import Adam
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet18
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_dataloader
|
||||
|
||||
# Config
|
||||
BATCH_SIZE = 2
|
||||
NUM_CLASSES = 10
|
||||
|
||||
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
|
||||
clip_grad_norm=1.0,
|
||||
gradient_accumulation=4)
|
||||
|
||||
|
||||
def run_no_pipeline(rank, world_size, port):
|
||||
|
||||
# init dist env
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
# build model
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
# build dataloaders
|
||||
train_dataset = CIFAR10(root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
||||
]))
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
# build optimizer
|
||||
optimizer = Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
logger = get_dist_logger()
|
||||
rank = torch.distributed.get_rank()
|
||||
param_track = []
|
||||
grad_track = []
|
||||
next(model.parameters()).retain_grad()
|
||||
|
||||
engine.train()
|
||||
step = 0
|
||||
for img, label in train_dataloader:
|
||||
engine.zero_grad()
|
||||
img = img.cuda()
|
||||
label = label.cuda()
|
||||
output = engine(img)
|
||||
loss = engine.criterion(output, label)
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
|
||||
# check
|
||||
param_track.append(next(model.parameters())[0].clone())
|
||||
grad_track.append(next(model.parameters()).grad[0].clone())
|
||||
step += 1
|
||||
if step == CONFIG['gradient_accumulation']:
|
||||
break
|
||||
|
||||
assert not torch.all(grad_track[0] == grad_track[-1]), 'grad should be different in different iterations'
|
||||
assert torch.all(param_track[0] == param_track[1]) and not torch.all(param_track[0] == param_track[-1]), \
|
||||
'param should be the same in the first few iterations and only changed in the last iteration'
|
||||
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_engine():
|
||||
spawn(run_no_pipeline, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_engine()
|
Reference in New Issue
Block a user