[zero] add sharded grad and refactor grad hooks for ShardedModel (#287)

commit 9b07ac81d4 (parent 4fbb8db586)
Author: ver217
Date:   2022-03-02 18:28:29 +08:00 (committed by GitHub)
9 changed files with 305 additions and 75 deletions

View File

@@ -1,9 +1,10 @@
 from functools import partial
-from operator import imod
-from colossalai.utils import checkpoint
-import torch.nn as nn
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import checkpoint
 
 LOGGER = get_dist_logger()
@@ -34,6 +35,7 @@ CONFIG = dict(
     )
 )
+
 
 def checkpoint_wrapper(module, enable=True):
     if enable:
         module.forward = partial(checkpoint, module.forward)
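
Note: the wrapper above swaps the module's bound forward for partial(checkpoint, module.forward), so a plain call to the module now runs under activation checkpointing and its activations are recomputed during backward instead of being stored. A minimal usage sketch (assuming colossalai.utils.checkpoint follows the torch.utils.checkpoint convention of checkpoint(fn, *inputs)):

    import torch
    import torch.nn as nn

    layer = nn.Linear(5, 5)
    checkpoint_wrapper(layer, enable=True)   # layer.forward is now partial(checkpoint, original forward)

    x = torch.rand(2, 5, requires_grad=True)
    loss = layer(x).sum()
    loss.backward()   # activations of the wrapped layer are recomputed here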
@@ -61,6 +63,7 @@ class Net(nn.Module):
             x = layer(x)
         return x
 
+
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
         return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
@@ -72,7 +75,8 @@ def check_grads(model, zero_model, loose=False):
         zero_grad = zero_p.grad.clone().to(p.device)
         assert p.grad.dtype == zero_grad.dtype
         assert allclose(p.grad, zero_grad, loose=loose)
-        LOGGER.info(torch.sum(p.grad-zero_grad))
+        LOGGER.info(torch.sum(p.grad - zero_grad))
+
 
 def check_params(model, zero_model, loose=False):
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
@@ -80,3 +84,30 @@ def check_params(model, zero_model, loose=False):
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose)
+
+
+def check_grads_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_grad = zero_p.grad.clone().to(p.device)
+        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        grad = chunks[rank]
+        if zero_grad.size(0) > grad.size(0):
+            zero_grad = zero_grad[:grad.size(0)]
+        assert grad.dtype == zero_grad.dtype
+        assert allclose(grad, zero_grad, loose=loose)
+
+
+def check_params_padding(model, zero_model, loose=False):
+    rank = dist.get_rank()
+    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
+        zero_p = zero_p.clone().to(p.device)
+        chunks = torch.flatten(p).chunk(dist.get_world_size())
+        if rank >= len(chunks):
+            continue
+        p = chunks[rank]
+        if zero_p.size(0) > p.size(0):
+            zero_p = zero_p[:p.size(0)]
+        assert p.dtype == zero_p.dtype
+        assert allclose(p, zero_p, loose=loose)
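
The padding-aware checks above compare one rank's gradient (or parameter) shard against the matching chunk of the unsharded reference: the reference tensor is flattened and split into world_size chunks, and because the shard may be zero-padded to a uniform length, it is truncated to the reference chunk's length before comparison. A small single-process illustration of that alignment (the concrete numbers are made up for the example):

    import torch

    world_size = 2
    full_grad = torch.arange(5.0)                        # stands in for p.grad of the reference model
    chunks = torch.flatten(full_grad).chunk(world_size)  # chunk sizes (3, 2)

    rank = 1
    grad = chunks[rank]                                  # tensor([3., 4.])
    zero_grad = torch.tensor([3.0, 4.0, 0.0])            # stands in for zero_p.grad: shard padded to length 3
    if zero_grad.size(0) > grad.size(0):
        zero_grad = zero_grad[:grad.size(0)]             # drop the padding before comparing
    assert torch.allclose(grad, zero_grad)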

View File

@@ -3,19 +3,18 @@
 import copy
 from functools import partial
-from operator import mod
-from pyexpat import model
 import colossalai
 import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from colossalai.logging import disable_existing_loggers
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
-from tests.test_zero_data_parallel.common import Net, CONFIG, check_grads
+from common import CONFIG, Net, check_grads, check_grads_padding
 
 
 def run_fwd_bwd(model, x, enable_autocast=False):
@@ -24,8 +23,11 @@ def run_fwd_bwd(model, x, enable_autocast=False):
         y = model(x)
         loss = y.sum()
     loss = loss.float()
-    loss.backward()
+    if isinstance(model, ShardedModelV2):
+        model.backward(loss)
+    else:
+        loss.backward()
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG,
@@ -34,7 +36,7 @@ def run_dist(rank, world_size, port):
                        host='localhost',
                        port=port,
                        backend='nccl')
     model = Net(checkpoint=True).cuda()
     zero_model = copy.deepcopy(model)
     zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
@@ -43,7 +45,10 @@ def run_dist(rank, world_size, port):
     x = torch.rand(2, 5).cuda()
     run_fwd_bwd(zero_model, x, False)
     run_fwd_bwd(model, x, False)
-    check_grads(model, zero_model)
+    if dist.get_world_size() > 1:
+        check_grads_padding(model, zero_model)
+    else:
+        check_grads(model, zero_model)
 
 
 @pytest.mark.dist
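
The pytest entry point is cut off above; in this test suite such run_dist functions are conventionally launched with torch.multiprocessing, roughly as sketched below (the test name and world size are assumptions for illustration, not taken from the diff):

    @pytest.mark.dist
    def test_shard_model_v2():                   # hypothetical name
        world_size = 2
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)    # each spawned process receives its rank as the first argument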

View File

@@ -14,7 +14,9 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.utils import checkpoint, free_port
 from colossalai.zero.sharded_model import ShardedModel
 from torch.nn.parallel import DistributedDataParallel as DDP
-from common import Net, allclose
+from common import Net, check_grads_padding, check_params_padding
 
 
 def run_step(model, optimizer, x, enable_autocast=False):
     model.train()
@@ -26,34 +28,6 @@ def run_step(model, optimizer, x, enable_autocast=False):
     loss.backward()
     optimizer.step()
-
-
-def check_grads_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_grad = zero_p.grad.clone().to(p.device)
-        chunks = torch.flatten(p.grad).chunk(4)
-        if rank >= len(chunks):
-            continue
-        grad = chunks[rank]
-        if zero_p.zero_shard_padding > 0:
-            zero_grad = zero_grad[:-zero_p.zero_shard_padding]
-        assert grad.dtype == zero_grad.dtype
-        assert allclose(grad, zero_grad, loose=loose)
-
-def check_params_padding(model, zero_model, loose=False):
-    rank = dist.get_rank()
-    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_shard_padding = zero_p.zero_shard_padding
-        zero_p = zero_p.clone().to(p.device)
-        chunks = torch.flatten(p).chunk(4)
-        if rank >= len(chunks):
-            continue
-        p = chunks[rank]
-        if zero_shard_padding > 0:
-            zero_p = zero_p[:-zero_shard_padding]
-        assert p.dtype == zero_p.dtype
-        assert allclose(p, zero_p, loose=loose)
 
 
 def decode_booleans(intval, bits):
     res = []