mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-04 23:24:25 +00:00
Feature/zero (#279)
* add zero1 (#209) * add zero1 * add test zero1 * update zero stage 1 develop (#212) * Implement naive zero3 (#240) * naive zero3 works well * add zero3 param manager * add TODOs in comments * add gather full param ctx * fix sub module streams * add offload * fix bugs of hook and add unit tests * fix bugs of hook and add unit tests (#252) * add gather full param ctx * fix sub module streams * add offload * fix bugs of hook and add unit tests * polish code and add state dict hook * fix bug * update unit test * refactor reconstructed zero code * clip_grad support zero3 and add unit test * add unit test for Zero3ParameterManager * [WIP] initialize the shard param class * [WIP] Yet another sharded model implementation (#274) * [WIP] initialize the shard param class * [WIP] Yes another implementation of shardModel. Using a better hook method. * torch.concat -> torch.cat * fix test_zero_level_1.py::test_zero_level_1 unitest * remove deepspeed implementation and refactor for the reconstructed zero module * polish zero dp unittests Co-authored-by: ver217 <lhx0217@gmail.com> Co-authored-by: Frank Lee <somerlee.9@gmail.com>
This commit is contained in:
187
tests/test_utils/test_zero_gradient_clippling.py
Normal file
187
tests/test_utils/test_zero_gradient_clippling.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import copy
|
||||
import operator as op
|
||||
from functools import partial, reduce
|
||||
from typing import List
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
|
||||
from colossalai.zero.sharded_model import ShardedModel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
|
||||
class Enumerator:
|
||||
def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
|
||||
self.arg_names = arg_names
|
||||
self.enums = Enumerator.all_enumerate(arg_values)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.enums)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return {name: self.enums[idx][i] for i, name in enumerate(self.arg_names)}
|
||||
|
||||
@staticmethod
|
||||
def all_enumerate(args: List[tuple]):
|
||||
num_states = reduce(op.mul, map(lambda xs: len(xs), args))
|
||||
idxs = [0] * len(args)
|
||||
states = []
|
||||
for _ in range(num_states):
|
||||
states.append(tuple(args[j][idx] for j, idx in enumerate(idxs)))
|
||||
if len(states) == num_states:
|
||||
break
|
||||
i = 0
|
||||
while idxs[i] + 1 == len(args[i]):
|
||||
idxs[i] = 0
|
||||
i += 1
|
||||
idxs[i] += 1
|
||||
return states
|
||||
|
||||
|
||||
def checkpoint_wrapper(module, enable=True):
|
||||
if enable:
|
||||
module.forward = partial(checkpoint, module.forward)
|
||||
return module
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.fc2 = nn.Linear(5, 5)
|
||||
self.fc3 = nn.Linear(5, 1)
|
||||
if checkpoint:
|
||||
self.fc1 = checkpoint_wrapper(self.fc1)
|
||||
self.layers = [
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc1,
|
||||
self.fc2,
|
||||
self.fc3
|
||||
]
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def run_step(model, optimizer, x, enable_autocast=False, norm_type=2.0):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss = loss.float()
|
||||
loss.backward()
|
||||
clip_grad(model, norm_type)
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def clip_grad(model, norm_type):
|
||||
if isinstance(model, DDP):
|
||||
clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=norm_type)
|
||||
else:
|
||||
clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=norm_type)
|
||||
|
||||
|
||||
def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
|
||||
if loose:
|
||||
return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
|
||||
return torch.allclose(tensor_a, tensor_b)
|
||||
|
||||
|
||||
def check_grads(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_grad = zero_p.grad.clone().to(p.device)
|
||||
chunks = torch.flatten(p.grad).chunk(4)
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
grad = chunks[rank]
|
||||
if zero_p.zero_shard_padding > 0:
|
||||
zero_grad = zero_grad[:-zero_p.zero_shard_padding]
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
assert allclose(grad, zero_grad, loose=loose)
|
||||
|
||||
|
||||
def check_params(model, zero_model, loose=False):
|
||||
rank = dist.get_rank()
|
||||
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
|
||||
zero_shard_padding = zero_p.zero_shard_padding
|
||||
zero_p = zero_p.clone().to(p.device)
|
||||
chunks = torch.flatten(p).chunk(4)
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
p = chunks[rank]
|
||||
if zero_shard_padding > 0:
|
||||
zero_p = zero_p[:-zero_shard_padding]
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose)
|
||||
|
||||
|
||||
def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
|
||||
model = Net(checkpoint=checkpoint).cuda()
|
||||
zero_model = copy.deepcopy(model)
|
||||
ddp_model = DDP(model)
|
||||
|
||||
offload_config = {}
|
||||
if offload:
|
||||
offload_config['device'] = 'cpu'
|
||||
zero_model = zero_model.cpu()
|
||||
zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
|
||||
|
||||
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(ddp_model, optimizer, x, enable_autocast=fp16, norm_type=norm_type)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=fp16, norm_type=norm_type)
|
||||
check_grads(ddp_model, zero_model)
|
||||
check_params(ddp_model, zero_model)
|
||||
for _ in range(5):
|
||||
x = torch.rand(2, 5).cuda()
|
||||
run_step(ddp_model, optimizer, x, enable_autocast=False, norm_type=norm_type)
|
||||
run_step(zero_model, zero_optimizer, x, enable_autocast=False, norm_type=norm_type)
|
||||
check_grads(ddp_model, zero_model, loose=True)
|
||||
check_params(ddp_model, zero_model, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={},
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
args = ['checkpoint', 'fp16', 'offload', 'norm_type']
|
||||
arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
|
||||
arg_enumerator = Enumerator(args, arg_values)
|
||||
|
||||
for kwargs in arg_enumerator:
|
||||
if dist.get_rank() == 0:
|
||||
print(kwargs)
|
||||
check_config(**kwargs)
|
||||
check_config()
|
||||
|
||||
|
||||
@ pytest.mark.dist
|
||||
def test_zero_clip_grad():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_clip_grad()
|
||||
Reference in New Issue
Block a user