mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)
This commit is contained in:
79
tests/test_utils/test_norm_gradient_clipping.py
Normal file
79
tests/test_utils/test_norm_gradient_clipping.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port, get_current_device
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from functools import partial
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils.common import clip_grad_norm
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
||||
def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
|
||||
return abs(num - other) <= atol + rtol * other
|
||||
|
||||
|
||||
def shard_param(p: ColoParameter) -> None:
|
||||
pg = p.get_process_group()
|
||||
p._redistribute(distspec.shard([0], [pg.tp_world_size()]))
|
||||
p.grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()].clone().detach()
|
||||
|
||||
|
||||
def check_grad_equal(p: Parameter, colo_p: ColoParameter) -> None:
|
||||
pg = colo_p.get_process_group()
|
||||
if p.shape != colo_p.shape:
|
||||
grad = p.grad.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()]
|
||||
else:
|
||||
grad = p.grad
|
||||
assert torch.allclose(grad, colo_p.grad), f'diff: {torch.abs(grad - colo_p.grad)}'
|
||||
|
||||
|
||||
@parameterize('dtype', [torch.float])
|
||||
@parameterize('device', ['mixed', 'cuda', 'cpu'])
|
||||
@parameterize('norm_type', [2.0, 3.0, float('inf')])
|
||||
def run_grad_clip_norm(world_size: int, dtype: torch.dtype, device: str, norm_type: float):
|
||||
print(f'{world_size}, {dtype}, {device}, {norm_type}')
|
||||
cuda_device = get_current_device()
|
||||
devices = [cuda_device] * 4
|
||||
if device == 'cpu':
|
||||
devices = [torch.device('cpu')] * 4
|
||||
elif device == 'mixed':
|
||||
devices = [cuda_device] * 2 + [torch.device('cpu')] * 2
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
params = [Parameter(torch.empty(4, 4, dtype=dtype, device=devices[i])) for i in range(4)]
|
||||
colo_params = [
|
||||
ColoParameter(torch.empty(4, 4, dtype=dtype, device=devices[i]), spec=ColoTensorSpec(pg)) for i in range(4)
|
||||
]
|
||||
for p, colo_p in zip(params, colo_params):
|
||||
grad = torch.rand_like(p)
|
||||
p.grad = grad
|
||||
colo_p.grad = grad.clone().detach()
|
||||
shard_param(colo_params[0])
|
||||
shard_param(colo_params[2])
|
||||
torch_norm = clip_grad_norm_(params, 1.0, norm_type=norm_type)
|
||||
colo_norm = clip_grad_norm(colo_params, 1.0, norm_type=norm_type)
|
||||
assert close(torch_norm, colo_norm), f'diff: {abs(torch_norm-colo_norm)}'
|
||||
for p, colo_p in zip(params, colo_params):
|
||||
check_grad_equal(p, colo_p)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_grad_clip_norm(world_size=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_clip_grad(world_size: int):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_clip_grad(2)
|
Reference in New Issue
Block a user