[tensor] ColoTensor supports ZeRO (#1015)
* impl chunk manager
* impl param op hook
* add reduce_chunk
* add zero hook v2
* add zero dp
* fix TensorInfo
* impl load balancing when using zero without chunk
* fix zero hook
* polish chunk
* fix bugs
* ddp ok
* zero ok
* polish code
* fix bugs about load balancing
* polish code
* polish code
* add end-to-end test
* polish code
* polish code
* polish code
* fix typo
* add test_chunk
* fix bugs
* fix bugs
* polish code
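Most of the bullets above revolve around one idea: instead of managing every parameter tensor individually, a chunk manager packs many small parameters into large fixed-size buffers ("chunks"), so ZeRO can reduce and free memory one chunk at a time rather than one parameter at a time. The sketch below illustrates that grouping only; it is a hypothetical toy, not ColossalAI's `ChunkManager` API, and `Chunk`, `append`, and `reduce` are made-up names.

```python
# Toy illustration of chunk-based grouping for ZeRO-style reduction.
# Hypothetical names throughout -- not the actual ColossalAI ChunkManager.
import torch
import torch.distributed as dist


class Chunk:
    def __init__(self, capacity: int, dtype=torch.half):
        self.data = torch.zeros(capacity, dtype=dtype)
        self.offset = 0  # next free element in the flat buffer

    def append(self, tensor: torch.Tensor) -> torch.Tensor:
        # Copy the tensor into the flat buffer and hand back a view,
        # so many small parameters share one contiguous allocation.
        n = tensor.numel()
        if self.offset + n > self.data.numel():
            raise RuntimeError('chunk is full')
        chunk_slice = self.data[self.offset:self.offset + n]
        chunk_slice.copy_(tensor.detach().flatten())
        self.offset += n
        return chunk_slice.view(tensor.shape)

    def reduce(self):
        # One collective per chunk instead of one per parameter.
        if dist.is_initialized():
            dist.all_reduce(self.data)


chunk = Chunk(capacity=1024)
params = [torch.randn(16, 16, dtype=torch.half) for _ in range(4)]
views = [chunk.append(p) for p in params]
chunk.reduce()  # no-op here without an initialized process group
```

The diff below wires this idea into a new `ColoDDPV2` wrapper.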
```diff
@@ -3,8 +3,10 @@ import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+from functools import partial
+from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
 
-__all__ = ['ColoDDP']
+__all__ = ['ColoDDP', 'ColoDDPV2']
 
 
 def free_storage(data: torch.Tensor) -> None:
```
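The new imports pull in the parameter-op hook machinery the rest of the diff relies on: `use_param_op_hooks` installs a hook object whose callbacks run around every operation touching wrapped parameters, which is how `ZeROHookV2` can materialize a chunk just before it is needed and schedule its release afterwards. A rough sketch of that pattern, assuming the hook exposes pre/post callbacks (the `ParamOpHook`, `use_hooks`, and `run_op` names below are illustrative, not ColossalAI's actual interface):

```python
# Sketch of the param-op hook pattern: a context manager installs a hook
# whose pre/post callbacks bracket every op on managed parameters.
# Hypothetical shape for illustration only.
from contextlib import contextmanager

import torch

_CURRENT_HOOK = None


class ParamOpHook:
    def pre_op(self, params):
        # e.g. fetch/materialize the chunks holding these params
        print(f'pre-op on {len(params)} params')

    def post_op(self, params):
        # e.g. mark chunks as eligible for (lazy) release
        print(f'post-op on {len(params)} params')


@contextmanager
def use_hooks(hook: ParamOpHook):
    global _CURRENT_HOOK
    prev, _CURRENT_HOOK = _CURRENT_HOOK, hook
    try:
        yield
    finally:
        _CURRENT_HOOK = prev


def run_op(params, op):
    # A framework would call this around each tensor op it intercepts.
    if _CURRENT_HOOK is not None:
        _CURRENT_HOOK.pre_op(params)
    out = op()
    if _CURRENT_HOOK is not None:
        _CURRENT_HOOK.post_op(params)
    return out


params = [torch.randn(3)]
with use_hooks(ParamOpHook()):
    out = run_op(params, lambda: params[0] * 2)
```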
```diff
@@ -76,3 +78,54 @@ class ColoDDP(torch.nn.Module):
                     else:
                         p._saved_grad.requires_grad_(False)
                     p._saved_grad.zero_()
+
+
+class ColoDDPV2(ColoDDP):
+
+    def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(module)
+        self.chunk_manager = chunk_manager
+        self.param_op_hook = ZeROHookV2(chunk_manager)
+        self.fp32_params = []
+        # TODO: get param order and filter unused params
+        for p in module.parameters():
+            assert p.dtype == torch.half
+            fp32_p = p.float()
+            self.chunk_manager.append_tensor(p, 'fp16_param')
+            self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
+            self.fp32_params.append(fp32_p)
+
+    def forward(self, *args, **kwargs):
+        self.module.zero_grad(set_to_none=True)
+        for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
+            if not self.chunk_manager.is_chunk_free(p):
+                self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
+        with use_param_op_hooks(self.param_op_hook):
+            outputs = self.module(*args, **kwargs)
+        self.chunk_manager.exec_lazy_release()
+        return outputs
+
+    def backward(self, loss: torch.Tensor):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            loss.backward()
+        self.chunk_manager.exec_lazy_release()
+        for p in self.module.parameters():
+            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+                p.grad = None
+            else:
+                p.grad = p.data
+
+    def grad_handle(self, p, grad):
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        with torch._C.DisableTorchFunction():
+            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
+            self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
+            self.chunk_manager.reduce_chunk(p)
+            self.chunk_manager.release_chunk(p)
+        return empty_grad
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        self.module.zero_grad(set_to_none=True)
```
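Putting the pieces together: `__init__` registers each fp16 parameter plus an fp32 master copy with the chunk manager, `forward` refreshes the fp16 chunks from the fp32 copies and runs the module under the param-op hook, and `grad_handle` averages each gradient across data-parallel ranks, copies it into its chunk slice, reduces and releases the chunk, and hands autograd back a storage-freed placeholder. A hedged usage sketch follows; the `ChunkManager(chunk_size=...)` constructor, the import path, and the launch setup are assumptions, not shown in this diff:

```python
# Hedged usage sketch for ColoDDPV2; assumes a distributed environment has
# already been launched (e.g. via colossalai.launch) and that ChunkManager
# takes a chunk size -- both assumptions, not taken from this diff.
import torch
from colossalai.tensor import ChunkManager
from colossalai.nn.parallel import ColoDDPV2  # module path assumed

model = torch.nn.Linear(1024, 1024).half().cuda()  # params must be fp16
chunk_manager = ChunkManager(chunk_size=1024 * 1024)  # signature assumed
model = ColoDDPV2(model, chunk_manager)

x = torch.randn(8, 1024, dtype=torch.half, device='cuda')
loss = model(x).float().sum()
model.backward(loss)  # use the wrapper's backward, not loss.backward()
```

Note that training code has to go through the wrapper's `backward` rather than calling `loss.backward()` directly: the diff only activates the hook context and the lazy-release pass inside `ColoDDPV2.backward`.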
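One detail worth spelling out is the `empty_grad` returned by `grad_handle`: `free_storage` (whose body this diff does not show) releases a tensor's underlying memory while keeping its shape metadata, so autograd receives a valid-looking placeholder that holds no storage while the real gradient lives inside the reduced chunk. A minimal sketch, assuming `free_storage` simply resizes the underlying storage to zero:

```python
import torch


def free_storage(data: torch.Tensor) -> None:
    # Resize the underlying storage to zero elements: the tensor keeps its
    # shape metadata but no longer owns memory. Assumed implementation; the
    # diff above shows only the signature.
    if data.storage().size() > 0:
        data.storage().resize_(0)


grad = torch.randn(4, 4)
placeholder = torch.empty_like(grad)
free_storage(placeholder)
print(placeholder.shape)             # torch.Size([4, 4]) -- metadata intact
print(placeholder.storage().size())  # 0 -- no memory held
```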