Feature/zero (#279)

* add zero1 (#209)

* add zero1

* add test zero1

* update zero stage 1 develop (#212)

* Implement naive zero3 (#240)

* naive zero3 works well

* add zero3 param manager

* add TODOs in comments

* add gather full param ctx

* fix sub module streams

* add offload

* fix hook bugs and add unit tests

* fix hook bugs and add unit tests (#252)

* add gather full param ctx

* fix sub module streams

* add offload

* fix hook bugs and add unit tests

* polish code and add state dict hook

* fix bug

* update unit test

* refactor reconstructed zero code

* add zero3 support to clip_grad and add a unit test

* add unit test for Zero3ParameterManager

* [WIP] initialize the shard param class

* [WIP] Yet another sharded model implementation (#274)

* [WIP] initialize the shard param class

* [WIP] Yet another implementation of ShardedModel, using a better hook method.

* torch.concat -> torch.cat

* fix the test_zero_level_1.py::test_zero_level_1 unit test

* remove deepspeed implementation and refactor for the reconstructed zero module

* polish zero dp unit tests

Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Authored by Jiarui Fang on 2022-03-01 18:17:01 +08:00
Committed by Frank Lee
parent 08eccfe681
commit 5a560a060a
40 changed files with 3912 additions and 6493 deletions


@@ -0,0 +1,187 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
import operator as op
from functools import partial, reduce
from typing import List

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.logging import disable_existing_loggers
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.zero.sharded_model import ShardedModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
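
# Enumerator walks the full cartesian product of the argument values with a
# mixed-radix counter (similar in spirit to itertools.product) and yields one
# keyword dict per combination for check_config below.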
class Enumerator:

    def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
        self.arg_names = arg_names
        self.enums = Enumerator.all_enumerate(arg_values)

    def __len__(self):
        return len(self.enums)

    def __getitem__(self, idx):
        return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}

    @staticmethod
    def all_enumerate(args: List[tuple]):
        num_states = reduce(op.mul, map(lambda xs: len(xs), args))
        idxs = [0] * len(args)
        states = []
        for _ in range(num_states):
            states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
            if len(states) == num_states:
                break
            i = 0
            while idxs[i] + 1 == len(args[i]):
                idxs[i] = 0
                i += 1
            idxs[i] += 1
        return states
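
# checkpoint_wrapper swaps a module's forward for an activation-checkpointed
# version, so activations are recomputed in the backward pass instead of stored.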
def checkpoint_wrapper(module, enable=True):
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module
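
# A small MLP in which fc1 and fc2 are reused twice in the forward pass, so the
# test also exercises parameters shared between several call sites.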
class Net(nn.Module):

    def __init__(self, checkpoint=False) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)
        if checkpoint:
            self.fc1 = checkpoint_wrapper(self.fc1)
        self.layers = [
            self.fc1,
            self.fc2,
            self.fc1,
            self.fc2,
            self.fc3
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
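
# One training iteration: forward (optionally under AMP autocast), backward,
# gradient clipping, and an optimizer step.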
def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
    model.train()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    loss.backward()
    clip_grad(model, norm_type)
    optimizer.step()
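
# The DDP baseline clips with torch's clip_grad_norm_; the sharded model goes
# through colossalai's clip_grad_norm_fp32, which (per the commit message) now
# supports zero3-sharded gradients.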
def clip_grad(model, norm_type):
    if isinstance(model, DDP):
        clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
    else:
        clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)


def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
    if loose:
        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
    return torch.allclose(tensor_a, tensor_b)
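
# check_grads / check_params compare the DDP baseline's full gradients and
# parameters against the local shard held by the zero model: each rank takes
# its chunk of the flattened tensor and strips any trailing shard padding.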
def check_grads(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_grad = zero_p.grad.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(4)
        if rank >= len(chunks):
            continue
        grad = chunks[rank]
        if zero_p.zero_shard_padding > 0:
            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
        assert grad.dtype == zero_grad.dtype
        assert allclose(grad, zero_grad, loose=loose)


def check_params(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        zero_shard_padding = zero_p.zero_shard_padding
        zero_p = zero_p.clone().to(p.device)
        chunks = torch.flatten(p).chunk(4)
        if rank >= len(chunks):
            continue
        p = chunks[rank]
        if zero_shard_padding > 0:
            zero_p = zero_p[:-zero_shard_padding]
        assert p.dtype == zero_p.dtype
        assert allclose(p, zero_p, loose=loose)
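
# Build a DDP baseline and a ShardedModel copy of the same Net, run a few steps
# with the requested autocast setting and a few more without it, and check that
# gradients and parameters stay in agreement after every step.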
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
    model = Net(checkpoint=checkpoint).cuda()
    zero_model = copy.deepcopy(model)
    ddp_model = DDP(model)

    offload_config = {}
    if offload:
        offload_config['device'] = 'cpu'
        zero_model = zero_model.cpu()
    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
        check_grads(ddp_model, zero_model)
        check_params(ddp_model, zero_model)
    for _ in range(5):
        x = torch.rand(2, 5).cuda()
        run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
        run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
        check_grads(ddp_model, zero_model, loose=True)
        check_params(ddp_model, zero_model, loose=True)
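
# Worker entry point: initialize colossalai on this rank, then sweep every
# checkpoint / fp16 / offload / norm_type combination produced by Enumerator.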
def run_dist(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    args = ['checkpoint', 'fp16', 'offload', 'norm_type']
    arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
    arg_enumerator = Enumerator(args, arg_values)

    for kwargs in arg_enumerator:
        if dist.get_rank() == 0:
            print(kwargs)
        check_config(**kwargs)
    check_config()
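
# Spawns four worker processes (world_size = 4); the test assumes 4 CUDA
# devices and an available nccl backend.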
@pytest.mark.dist
def test_zero_clip_grad():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_clip_grad()