Feature/zero (#279)

* add zero1 (#209)

* add zero1

* add test zero1

* update zero stage 1 develop (#212)

* Implement naive zero3 (#240)

* naive zero3 works well

* add zero3 param manager

* add TODOs in comments

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* fix bugs of hook and add unit tests (#252)

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* polish code and add state dict hook

* fix bug

* update unit test

* refactor reconstructed zero code

* clip_grad support zero3 and add unit test

* add unit test for Zero3ParameterManager

* [WIP] initialize the shard param class

* [WIP] Yet another sharded model implementation (#274)

* [WIP] initialize the shard param class

* [WIP] Yes another implementation of shardModel. Using a better hook method.

* torch.concat -> torch.cat

* fix test_zero_level_1.py::test_zero_level_1 unitest

* remove deepspeed implementation and refactor for the reconstructed zero module

* polish zero dp unittests

Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
Jiarui Fang
2022-03-01 18:17:01 +08:00
committed by GitHub
parent f84b15719f
commit 3280869358
40 changed files with 3912 additions and 6493 deletions

View File

@@ -0,0 +1,82 @@
from functools import partial
from operator import imod
from colossalai.utils import checkpoint
import torch.nn as nn
import torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
LOGGER = get_dist_logger()
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=3,
verbose=False,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
fast_init=False
),
offload_param_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
buffer_size=1e8,
max_in_cpu=1e9
)
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
def checkpoint_wrapper(module, enable=True):
if enable:
module.forward = partial(checkpoint, module.forward)
return module
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 1)
if checkpoint:
self.fc1 = checkpoint_wrapper(self.fc1)
self.layers = [
self.fc1,
self.fc2,
self.fc1,
self.fc2,
self.fc3
]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
if loose:
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
return torch.allclose(tensor_a, tensor_b)
def check_grads(model, zero_model, loose=False):
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
assert p.grad.dtype == zero_grad.dtype
assert allclose(p.grad, zero_grad, loose=loose)
LOGGER.info(torch.sum(p.grad-zero_grad))
def check_params(model, zero_model, loose=False):
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_p = zero_p.clone().to(p.device)
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
from operator import mod
from pyexpat import model
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
def run_fwd_bwd(model, x, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
model = Net(checkpoint=True).cuda()
zero_model = copy.deepcopy(model)
zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
for _ in range(2):
x = torch.rand(2, 5).cuda()
run_fwd_bwd(zero_model, x, False)
run_fwd_bwd(model, x, False)
check_grads(model, zero_model)
@pytest.mark.dist
def test_shard_model_v2():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_shard_model_v2()

View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from asyncio.log import logger
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_param import ShardParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG
def run_shard_param_check(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
logger = get_dist_logger()
model = Net()
# add an attribute as ca_attr to hijack the access to param.data
for _, param in model.named_parameters():
numel_ref = (param.numel() + world_size - 1) // world_size
param.ca_attr = ShardParam(param)
param.ca_attr.shard()
param_data = param.ca_attr.payload(torch.device('cpu'))
logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
assert(numel_ref == param_data.numel())
for _, param in model.named_parameters():
param.ca_attr.gather()
param_data = param.ca_attr.payload(torch.device('cpu'))
logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])
disable_existing_loggers([logger])
@pytest.mark.dist
def test_run_shard_shape():
world_size = 2
run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_run_shard_shape()

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from common import Net, check_grads, check_params, check_params
def checkpoint_wrapper(module, enable=True):
if enable:
module.forward = partial(checkpoint, module.forward)
return module
class Net(nn.Module):
def __init__(self, checkpoint=False) -> None:
super().__init__()
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 1)
if checkpoint:
self.fc1 = checkpoint_wrapper(self.fc1)
self.layers = [
self.fc1,
self.fc2,
self.fc1,
self.fc2,
self.fc3
]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def run_step(model, optimizer, x, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
optimizer.step()
def decode_booleans(intval, bits):
res = []
for bit in range(bits):
mask = 1 << bit
res.append((intval & mask) == mask)
return res
def check_config(checkpoint=False, fp16=False, offload=False):
model = Net(checkpoint=checkpoint).cuda()
zero_model = copy.deepcopy(model)
offload_config = {}
if offload:
offload_config['device'] = 'cpu'
zero_model = zero_model.cpu()
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(model, optimizer, x, enable_autocast=fp16)
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
check_grads(model, zero_model)
check_params(model, zero_model)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(model, optimizer, x, enable_autocast=False)
run_step(zero_model, zero_optimizer, x, enable_autocast=False)
check_grads(model, zero_model, loose=True)
check_params(model, zero_model, loose=True)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
args = ['checkpoint', 'fp16', 'offload']
def pack_args(i):
booleans = decode_booleans(i, len(args))
return {arg: booleans[idx] for idx, arg in enumerate(args)}
for j in range(2 ** len(args)):
kwargs = pack_args(j)
print(kwargs)
check_config(**kwargs)
@pytest.mark.dist
def test_zero_level_3():
world_size = 1
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_3()

View File

@@ -0,0 +1,123 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from common import Net, allclose
def run_step(model, optimizer, x, enable_autocast=False):
model.train()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=enable_autocast):
y = model(x)
loss = y.sum()
loss = loss.float()
loss.backward()
optimizer.step()
def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_grad = zero_p.grad.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(4)
if rank >= len(chunks):
continue
grad = chunks[rank]
if zero_p.zero_shard_padding > 0:
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose)
def check_params_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
zero_shard_padding = zero_p.zero_shard_padding
zero_p = zero_p.clone().to(p.device)
chunks = torch.flatten(p).chunk(4)
if rank >= len(chunks):
continue
p = chunks[rank]
if zero_shard_padding > 0:
zero_p = zero_p[:-zero_shard_padding]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose)
def decode_booleans(intval, bits):
res = []
for bit in range(bits):
mask = 1 << bit
res.append((intval & mask) == mask)
return res
def check_config(checkpoint=False, fp16=False, offload=False):
model = Net(checkpoint=checkpoint).cuda()
zero_model = copy.deepcopy(model)
ddp_model = DDP(model)
offload_config = {}
if offload:
offload_config['device'] = 'cpu'
zero_model = zero_model.cpu()
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=fp16)
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
check_grads_padding(ddp_model, zero_model)
check_params_padding(ddp_model, zero_model)
for _ in range(5):
x = torch.rand(2, 5).cuda()
run_step(ddp_model, optimizer, x, enable_autocast=False)
run_step(zero_model, zero_optimizer, x, enable_autocast=False)
check_grads_padding(ddp_model, zero_model, loose=True)
check_params_padding(ddp_model, zero_model, loose=True)
def run_dist(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
args = ['checkpoint', 'fp16', 'offload']
def pack_args(i):
booleans = decode_booleans(i, len(args))
return {arg: booleans[idx] for idx, arg in enumerate(args)}
for j in range(2 ** len(args)):
kwargs = pack_args(j)
if dist.get_rank() == 0:
print(kwargs)
check_config(**kwargs)
@pytest.mark.dist
def test_zero_level_3():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_3()

View File

@@ -1,102 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
BATCH_SIZE = 16
IMG_SIZE = 224
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=2,
cpu_offload=True,
verbose=False,
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
# build model
model = resnet18(num_classes=10)
# build dataloader# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=BATCH_SIZE,
pin_memory=True,
drop_last=True)
# build optimizer and loss
# optimizer = build_optimizer(global_context.config.optimizer, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
# train
model.train()
for idx, (data, label) in enumerate(train_dataloader):
engine.zero_grad()
data = data.cuda()
label = label.cuda()
output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()
break
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
def test_zero_level_2():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_2()

View File

@@ -1,114 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_dataloader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
BATCH_SIZE = 16
IMG_SIZE = 224
CONFIG = dict(
fp16=dict(
mode=None,
),
zero=dict(
level=3,
verbose=False,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
fast_init=False
),
offload_param_config=dict(
device='cpu',
pin_memory=True,
buffer_count=5,
buffer_size=1e8,
max_in_cpu=1e9
)
),
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=1, mode=None)
)
)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
# build model
model = resnet18(num_classes=10)
# build dataloader# build dataloaders
train_dataset = CIFAR10(
root=Path(os.environ['DATA']),
download=True,
transform=transforms.Compose(
[
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
]
)
)
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
batch_size=BATCH_SIZE,
pin_memory=True,
drop_last=True)
# build optimizer and loss
# optimizer = build_optimizer(global_context.config.optimizer, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
# train
model.train()
for idx, (data, label) in enumerate(train_dataloader):
engine.zero_grad()
data = data.cuda()
label = label.cuda()
output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()
break
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
def test_zero_level_3():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_level_3()

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.utils import free_port
from common import CONFIG
def run_shard_shape_check(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
model = torch.nn.Linear(2, 4 * world_size)
gpc.init_parallel_groups()
Zero3ParameterManager(module=model, process_group=gpc.get_group(ParallelMode.DATA), offload_config=CONFIG.get('offload_param_config'))
assert(model.weight.numel() == 4 * 2)
assert(model.bias.numel() == 4)
@pytest.mark.dist
def test_run_shard_shape():
world_size = 2
run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_run_shard_shape()