mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-21 10:50:56 +00:00
[tensor] impl ColoDDP for ColoTensor (#1009)
* impl ColoDDP for ColoTensor * polish code
This commit is contained in:
parent
ae7c338105
commit
cefc29ff06
78
colossalai/nn/parallel.py
Normal file
78
colossalai/nn/parallel.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
__all__ = ['ColoDDP']
|
||||||
|
|
||||||
|
|
||||||
|
def free_storage(data: torch.Tensor) -> None:
|
||||||
|
"""Free underlying storage of a Tensor."""
|
||||||
|
if data.storage().size() > 0:
|
||||||
|
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
|
||||||
|
# is the sole occupant of the Storage.
|
||||||
|
assert data.storage_offset() == 0
|
||||||
|
data.storage().resize_(0)
|
||||||
|
|
||||||
|
|
||||||
|
class ColoDDP(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, module: torch.nn.Module) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.module = module
|
||||||
|
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
|
||||||
|
self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
|
||||||
|
for p in module.parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
p.register_hook(partial(self.grad_handle, p))
|
||||||
|
|
||||||
|
def parameters(self, recurse: bool = True):
|
||||||
|
return self.module.parameters(recurse)
|
||||||
|
|
||||||
|
def named_parameters(self, prefix: str = '', recurse: bool = True):
|
||||||
|
return self.module.named_parameters(prefix, recurse)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
self.module.zero_grad(set_to_none=True)
|
||||||
|
return self.module(*args, **kwargs)
|
||||||
|
|
||||||
|
def backward(self, loss: torch.Tensor):
|
||||||
|
loss.backward()
|
||||||
|
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
||||||
|
for p in self.module.parameters():
|
||||||
|
p.grad = p._saved_grad
|
||||||
|
|
||||||
|
def grad_handle(self, p, grad):
|
||||||
|
empty_grad = torch.empty_like(grad)
|
||||||
|
free_storage(empty_grad)
|
||||||
|
if self.dp_world_size > 1:
|
||||||
|
grad = grad / self.dp_world_size
|
||||||
|
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||||
|
with torch.cuda.stream(self.comm_stream):
|
||||||
|
dist.all_reduce(grad, group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
ColoDDP._save_grad(p, grad)
|
||||||
|
grad.record_stream(self.comm_stream)
|
||||||
|
else:
|
||||||
|
ColoDDP._save_grad(p, grad)
|
||||||
|
return empty_grad
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _save_grad(p, grad):
|
||||||
|
if hasattr(p, '_saved_grad'):
|
||||||
|
p._saved_grad.add_(grad)
|
||||||
|
else:
|
||||||
|
p._saved_grad = grad
|
||||||
|
|
||||||
|
def zero_grad(self, set_to_none: bool = False) -> None:
|
||||||
|
self.module.zero_grad(set_to_none=True)
|
||||||
|
for p in self.module.parameters():
|
||||||
|
if getattr(p, '_saved_grad', None) is not None:
|
||||||
|
if set_to_none:
|
||||||
|
p._saved_grad = None
|
||||||
|
else:
|
||||||
|
if p._saved_grad.grad_fn is not None:
|
||||||
|
p._saved_grad.detach_()
|
||||||
|
else:
|
||||||
|
p._saved_grad.requires_grad_(False)
|
||||||
|
p._saved_grad.zero_()
|
@ -9,8 +9,10 @@ from colossalai.utils import ColoInitContext
|
|||||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
|
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from _utils import tensor_equal, tensor_shard_equal
|
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from colossalai.nn.parallel import ColoDDP
|
||||||
|
|
||||||
|
|
||||||
def init_1d_row_spec(model):
|
def init_1d_row_spec(model):
|
||||||
@ -43,7 +45,7 @@ def check_grad_equal(model, torch_model):
|
|||||||
assert tensor_shard_equal(torch_p.grad, p.grad)
|
assert tensor_shard_equal(torch_p.grad, p.grad)
|
||||||
|
|
||||||
|
|
||||||
def run_gpt(init_spec_func):
|
def run_gpt(init_spec_func, use_ddp):
|
||||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
@ -51,18 +53,27 @@ def run_gpt(init_spec_func):
|
|||||||
model = model_builder()
|
model = model_builder()
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
torch_model = model_builder().cuda()
|
torch_model = model_builder().cuda()
|
||||||
|
if use_ddp:
|
||||||
|
model = ColoDDP(model)
|
||||||
|
torch_model = DDP(torch_model,
|
||||||
|
device_ids=[gpc.get_global_rank()],
|
||||||
|
process_group=gpc.get_group(ParallelMode.DATA))
|
||||||
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
|
||||||
torch_p.data.copy_(p)
|
torch_p.data.copy_(p)
|
||||||
init_spec_func(model)
|
init_spec_func(model)
|
||||||
check_param_equal(model, torch_model)
|
check_param_equal(model, torch_model)
|
||||||
model.train()
|
model.train()
|
||||||
torch_model.train()
|
torch_model.train()
|
||||||
|
set_seed(gpc.get_local_rank(ParallelMode.DATA))
|
||||||
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
for i, (input_ids, attn_mask) in enumerate(train_dataloader):
|
||||||
logits = model(input_ids, attn_mask)
|
logits = model(input_ids, attn_mask)
|
||||||
torch_logits = torch_model(input_ids, attn_mask)
|
torch_logits = torch_model(input_ids, attn_mask)
|
||||||
assert tensor_equal(torch_logits, logits)
|
assert tensor_equal(torch_logits, logits)
|
||||||
loss = criterion(logits, input_ids)
|
loss = criterion(logits, input_ids)
|
||||||
torch_loss = criterion(torch_logits, input_ids)
|
torch_loss = criterion(torch_logits, input_ids)
|
||||||
|
if use_ddp:
|
||||||
|
model.backward(loss)
|
||||||
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
torch_loss.backward()
|
torch_loss.backward()
|
||||||
check_grad_equal(model, torch_model)
|
check_grad_equal(model, torch_model)
|
||||||
@ -70,18 +81,22 @@ def run_gpt(init_spec_func):
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
def run_dist(rank, world_size, port, use_ddp):
|
||||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
if use_ddp and world_size == 1:
|
||||||
|
return
|
||||||
|
tp_world_size = world_size // 2 if use_ddp else world_size
|
||||||
|
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
|
||||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
run_gpt(init_1d_row_spec)
|
run_gpt(init_1d_row_spec, use_ddp)
|
||||||
run_gpt(init_1d_col_spec)
|
run_gpt(init_1d_col_spec, use_ddp)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize('world_size', [1, 4])
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
|
@pytest.mark.parametrize('use_ddp', [False, True])
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_gpt(world_size):
|
def test_gpt(world_size, use_ddp):
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user