mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
Feature/zero (#279)
* add zero1 (#209) * add zero1 * add test zero1 * update zero stage 1 develop (#212) * Implement naive zero3 (#240) * naive zero3 works well * add zero3 param manager * add TODOs in comments * add gather full param ctx * fix sub module streams * add offload * fix bugs of hook and add unit tests * fix bugs of hook and add unit tests (#252) * add gather full param ctx * fix sub module streams * add offload * fix bugs of hook and add unit tests * polish code and add state dict hook * fix bug * update unit test * refactor reconstructed zero code * clip_grad support zero3 and add unit test * add unit test for Zero3ParameterManager * [WIP] initialize the shard param class * [WIP] Yet another sharded model implementation (#274) * [WIP] initialize the shard param class * [WIP] Yes another implementation of shardModel. Using a better hook method. * torch.concat -> torch.cat * fix test_zero_level_1.py::test_zero_level_1 unitest * remove deepspeed implementation and refactor for the reconstructed zero module * polish zero dp unittests Co-authored-by: ver217 <lhx0217@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
187
tests/test_utils/test_zero_gradient_clippling.py
Normal file
187
tests/test_utils/test_zero_gradient_clippling.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import operator as op
|
||||
from functools import partial, reduce
|
||||
from typing import List
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
|
||||
from colossalai.zero.sharded_model import ShardedModel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
|
||||
class Enumerator:
|
||||
def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
|
||||
self.arg_names = arg_names
|
||||
self.enums = Enumerator.all_enumerate(arg_values)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.enums)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}
|
||||
|
||||
@staticmethod
|
||||
def all_enumerate(args: List[tuple]):
|
||||
num_states = reduce(op.mul, map(lambda xs: len(xs), args))
|
||||
idxs = [0] * len(args)
|
||||
states = []
|
||||
for _ in range(num_states):
|
||||
states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
|
||||
if len(states) == num_states:
|
||||
break
|
||||
i = 0
|
||||
while idxs[i] + 1 == len(args[i]):
|
||||
idxs[i] = 0
|
||||
i += 1
|
||||
idxs[i] += 1
|
||||
return states
|
||||
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
if enable:
|
||||
module.forward = partial(checkpoint, module.forward)
|
||||
return module
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.fc2 = nn.Linear(5, 5)
|
||||
self.fc3 = nn.Linear(5, 1)
|
||||
if checkpoint:
|
||||
self.fc1 = checkpoint_wrapper(self.fc1)
|
||||
self.layers = [
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc3
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss = loss.float()
|
||||
loss.backward()
|
||||
clip_grad(model, norm_type)
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def clip_grad(model, norm_type):
|
||||
if isinstance(model, DDP):
|
||||
clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
|
||||
else:
|
||||
clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)
|
||||
|
||||
|
||||
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
|
||||
if loose:
|
||||
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
|
||||
return torch.allclose(tensor_a, tensor_b)
|
||||
|
||||
|
||||
def check_grads(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_grad = zero_p.grad.clone().to(p.device)
|
||||
chunks = torch.flatten(p.grad).chunk(4)
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
grad = chunks[rank]
|
||||
if zero_p.zero_shard_padding > 0:
|
||||
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose)
|
||||
|
||||
|
||||
def check_params(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_shard_padding = zero_p.zero_shard_padding
|
||||
zero_p = zero_p.clone().to(p.device)
|
||||
chunks = torch.flatten(p).chunk(4)
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
p = chunks[rank]
|
||||
if zero_shard_padding > 0:
|
||||
zero_p = zero_p[:-zero_shard_padding]
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
|
||||
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
|
||||
model = Net(checkpoint=checkpoint).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
ddp_model = DDP(model)
|
||||
|
||||
offload_config = {}
|
||||
if offload:
|
||||
offload_config['device'] = 'cpu'
|
||||
zero_model = zero_model.cpu()
|
||||
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
|
||||
|
||||
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
|
||||
check_grads(ddp_model, zero_model)
|
||||
check_params(ddp_model, zero_model)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
|
||||
check_grads(ddp_model, zero_model, loose=True)
|
||||
check_params(ddp_model, zero_model, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={},
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
args = ['checkpoint', 'fp16', 'offload', 'norm_type']
|
||||
arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
|
||||
arg_enumerator = Enumerator(args, arg_values)
|
||||
|
||||
for kwargs in arg_enumerator:
|
||||
if dist.get_rank() == 0:
|
||||
print(kwargs)
|
||||
check_config(**kwargs)
|
||||
check_config()
|
||||
|
||||
|
||||
@ pytest.mark.dist
|
||||
def test_zero_clip_grad():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_clip_grad()
|
82
tests/test_zero_data_parallel/common.py
Normal file
82
tests/test_zero_data_parallel/common.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from functools import partial
|
||||
from operator import imod
|
||||
from colossalai.utils import checkpoint
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
|
||||
LOGGER = get_dist_logger()
|
||||
|
||||
CONFIG = dict(
|
||||
fp16=dict(
|
||||
mode=None,
|
||||
),
|
||||
zero=dict(
|
||||
level=3,
|
||||
verbose=False,
|
||||
offload_optimizer_config=dict(
|
||||
device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
fast_init=False
|
||||
),
|
||||
offload_param_config=dict(
|
||||
device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
buffer_size=1e8,
|
||||
max_in_cpu=1e9
|
||||
)
|
||||
),
|
||||
parallel=dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
)
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
if enable:
|
||||
module.forward = partial(checkpoint, module.forward)
|
||||
return module
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.fc2 = nn.Linear(5, 5)
|
||||
self.fc3 = nn.Linear(5, 1)
|
||||
if checkpoint:
|
||||
self.fc1 = checkpoint_wrapper(self.fc1)
|
||||
self.layers = [
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc3
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
|
||||
if loose:
|
||||
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
|
||||
return torch.allclose(tensor_a, tensor_b)
|
||||
|
||||
|
||||
def check_grads(model, zero_model, loose=False):
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_grad = zero_p.grad.clone().to(p.device)
|
||||
assert p.grad.dtype == zero_grad.dtype
|
||||
assert allclose(p.grad, zero_grad, loose=loose)
|
||||
LOGGER.info(torch.sum(p.grad-zero_grad))
|
||||
|
||||
def check_params(model, zero_model, loose=False):
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_p = zero_p.clone().to(p.device)
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
57
tests/test_zero_data_parallel/test_shard_model_v2.py
Normal file
57
tests/test_zero_data_parallel/test_shard_model_v2.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from operator import mod
|
||||
from pyexpat import model
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
|
||||
|
||||
|
||||
def run_fwd_bwd(model, x, enable_autocast=False):
|
||||
model.train()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss = loss.float()
|
||||
loss.backward()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
model = Net(checkpoint=True).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
for _ in range(2):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_fwd_bwd(zero_model, x, False)
|
||||
run_fwd_bwd(model, x, False)
|
||||
check_grads(model, zero_model)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_shard_model_v2():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shard_model_v2()
|
50
tests/test_zero_data_parallel/test_shard_param.py
Normal file
50
tests/test_zero_data_parallel/test_shard_param.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from asyncio.log import logger
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.zero.shard_param import ShardParam
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
||||
from tests.test_zero_data_parallel.common import Net, CONFIG
|
||||
|
||||
def run_shard_param_check(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
logger = get_dist_logger()
|
||||
model = Net()
|
||||
|
||||
# add an attribute as ca_attr to hijack the access to param.data
|
||||
for _, param in model.named_parameters():
|
||||
numel_ref = (param.numel() + world_size - 1) // world_size
|
||||
param.ca_attr = ShardParam(param)
|
||||
param.ca_attr.shard()
|
||||
param_data = param.ca_attr.payload(torch.device('cpu'))
|
||||
logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
|
||||
assert(numel_ref == param_data.numel())
|
||||
|
||||
for _, param in model.named_parameters():
|
||||
param.ca_attr.gather()
|
||||
param_data = param.ca_attr.payload(torch.device('cpu'))
|
||||
logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])
|
||||
|
||||
disable_existing_loggers([logger])
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_run_shard_shape():
|
||||
world_size = 2
|
||||
run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_run_shard_shape()
|
119
tests/test_zero_data_parallel/test_zero_dev_3.py
Normal file
119
tests/test_zero_data_parallel/test_zero_dev_3.py
Normal file
@@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import checkpoint, free_port
|
||||
from colossalai.zero.sharded_model import ShardedModel
|
||||
from common import Net, check_grads, check_params, check_params
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
if enable:
|
||||
module.forward = partial(checkpoint, module.forward)
|
||||
return module
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.fc2 = nn.Linear(5, 5)
|
||||
self.fc3 = nn.Linear(5, 1)
|
||||
if checkpoint:
|
||||
self.fc1 = checkpoint_wrapper(self.fc1)
|
||||
self.layers = [
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc3
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_step(model, optimizer, x, enable_autocast=False):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss = loss.float()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def decode_booleans(intval, bits):
|
||||
res = []
|
||||
for bit in range(bits):
|
||||
mask = 1 << bit
|
||||
res.append((intval & mask) == mask)
|
||||
return res
|
||||
|
||||
|
||||
def check_config(checkpoint=False, fp16=False, offload=False):
|
||||
model = Net(checkpoint=checkpoint).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
|
||||
offload_config = {}
|
||||
if offload:
|
||||
offload_config['device'] = 'cpu'
|
||||
zero_model = zero_model.cpu()
|
||||
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(model, optimizer, x, enable_autocast=fp16)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
|
||||
check_grads(model, zero_model)
|
||||
check_params(model, zero_model)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(model, optimizer, x, enable_autocast=False)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=False)
|
||||
check_grads(model, zero_model, loose=True)
|
||||
check_params(model, zero_model, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={},
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
args = ['checkpoint', 'fp16', 'offload']
|
||||
|
||||
def pack_args(i):
|
||||
booleans = decode_booleans(i, len(args))
|
||||
return {arg: booleans[idx] for idx, arg in enumerate(args)}
|
||||
|
||||
for j in range(2 ** len(args)):
|
||||
kwargs = pack_args(j)
|
||||
print(kwargs)
|
||||
check_config(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_zero_level_3():
|
||||
world_size = 1
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_level_3()
|
123
tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
Normal file
123
tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
Normal file
@@ -0,0 +1,123 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import checkpoint, free_port
|
||||
from colossalai.zero.sharded_model import ShardedModel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from common import Net, allclose
|
||||
|
||||
def run_step(model, optimizer, x, enable_autocast=False):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss = loss.float()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
def check_grads_padding(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_grad = zero_p.grad.clone().to(p.device)
|
||||
chunks = torch.flatten(p.grad).chunk(4)
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
grad = chunks[rank]
|
||||
if zero_p.zero_shard_padding > 0:
|
||||
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose)
|
||||
|
||||
|
||||
def check_params_padding(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_shard_padding = zero_p.zero_shard_padding
|
||||
zero_p = zero_p.clone().to(p.device)
|
||||
chunks = torch.flatten(p).chunk(4)
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
p = chunks[rank]
|
||||
if zero_shard_padding > 0:
|
||||
zero_p = zero_p[:-zero_shard_padding]
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
|
||||
def decode_booleans(intval, bits):
|
||||
res = []
|
||||
for bit in range(bits):
|
||||
mask = 1 << bit
|
||||
res.append((intval & mask) == mask)
|
||||
return res
|
||||
|
||||
|
||||
def check_config(checkpoint=False, fp16=False, offload=False):
|
||||
model = Net(checkpoint=checkpoint).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
ddp_model = DDP(model)
|
||||
|
||||
offload_config = {}
|
||||
if offload:
|
||||
offload_config['device'] = 'cpu'
|
||||
zero_model = zero_model.cpu()
|
||||
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
|
||||
|
||||
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(ddp_model, optimizer, x, enable_autocast=fp16)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
|
||||
check_grads_padding(ddp_model, zero_model)
|
||||
check_params_padding(ddp_model, zero_model)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(ddp_model, optimizer, x, enable_autocast=False)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=False)
|
||||
check_grads_padding(ddp_model, zero_model, loose=True)
|
||||
check_params_padding(ddp_model, zero_model, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={},
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
args = ['checkpoint', 'fp16', 'offload']
|
||||
|
||||
def pack_args(i):
|
||||
booleans = decode_booleans(i, len(args))
|
||||
return {arg: booleans[idx] for idx, arg in enumerate(args)}
|
||||
|
||||
for j in range(2 ** len(args)):
|
||||
kwargs = pack_args(j)
|
||||
if dist.get_rank() == 0:
|
||||
print(kwargs)
|
||||
check_config(**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_zero_level_3():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_level_3()
|
@@ -1,102 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet18
|
||||
|
||||
BATCH_SIZE = 16
|
||||
IMG_SIZE = 224
|
||||
|
||||
CONFIG = dict(
|
||||
fp16=dict(
|
||||
mode=None,
|
||||
),
|
||||
zero=dict(
|
||||
level=2,
|
||||
cpu_offload=True,
|
||||
verbose=False,
|
||||
),
|
||||
parallel=dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
# build model
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
# build dataloader# build dataloaders
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
||||
]
|
||||
)
|
||||
)
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
# build optimizer and loss
|
||||
# optimizer = build_optimizer(global_context.config.optimizer, model)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
|
||||
# train
|
||||
model.train()
|
||||
for idx, (data, label) in enumerate(train_dataloader):
|
||||
engine.zero_grad()
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
break
|
||||
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_zero_level_2():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_level_2()
|
@@ -1,114 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils import free_port, get_dataloader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.models import resnet18
|
||||
|
||||
BATCH_SIZE = 16
|
||||
IMG_SIZE = 224
|
||||
|
||||
CONFIG = dict(
|
||||
fp16=dict(
|
||||
mode=None,
|
||||
),
|
||||
zero=dict(
|
||||
level=3,
|
||||
verbose=False,
|
||||
offload_optimizer_config=dict(
|
||||
device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
fast_init=False
|
||||
),
|
||||
offload_param_config=dict(
|
||||
device='cpu',
|
||||
pin_memory=True,
|
||||
buffer_count=5,
|
||||
buffer_size=1e8,
|
||||
max_in_cpu=1e9
|
||||
)
|
||||
),
|
||||
parallel=dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(size=1, mode=None)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
# build model
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
# build dataloader# build dataloaders
|
||||
train_dataset = CIFAR10(
|
||||
root=Path(os.environ['DATA']),
|
||||
download=True,
|
||||
transform=transforms.Compose(
|
||||
[
|
||||
transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
||||
]
|
||||
)
|
||||
)
|
||||
train_dataloader = get_dataloader(dataset=train_dataset,
|
||||
shuffle=True,
|
||||
batch_size=BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
drop_last=True)
|
||||
|
||||
# build optimizer and loss
|
||||
# optimizer = build_optimizer(global_context.config.optimizer, model)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
engine, train_dataloader, *args = colossalai.initialize(model=model,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
train_dataloader=train_dataloader)
|
||||
|
||||
# train
|
||||
model.train()
|
||||
for idx, (data, label) in enumerate(train_dataloader):
|
||||
engine.zero_grad()
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
break
|
||||
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_zero_level_3():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_level_3()
|
41
tests/test_zero_data_parallel/test_zero_param_mgr.py
Normal file
41
tests/test_zero_data_parallel/test_zero_param_mgr.py
Normal file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.utils import free_port
|
||||
from common import CONFIG
|
||||
|
||||
def run_shard_shape_check(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
model = torch.nn.Linear(2, 4 * world_size)
|
||||
gpc.init_parallel_groups()
|
||||
Zero3ParameterManager(module=model, process_group=gpc.get_group(ParallelMode.DATA), offload_config=CONFIG.get('offload_param_config'))
|
||||
|
||||
assert(model.weight.numel() == 4 * 2)
|
||||
assert(model.bias.numel() == 4)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_run_shard_shape():
|
||||
world_size = 2
|
||||
run_func = partial(run_shard_shape_check, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_run_shard_shape()
|
@@ -88,6 +88,7 @@ def run_2d_parallel_vision_transformer_level_2(rank, world_size, port):
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
|
||||
def test_2d_vit_zero_level_2():
|
||||
world_size = 8
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_2, world_size=world_size, port=free_port())
|
||||
|
@@ -88,7 +88,7 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size, port):
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
|
||||
@pytest.mark.skip(reason="This test should be refactored for the reconstructed zero")
|
||||
def test_3d_vit_zero_level_3():
|
||||
world_size = 8
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size, port=free_port())
|
||||
|
Reference in New Issue
Block a user