[tensor] ColoTensor supports ZeRO (#1015)

* impl chunk manager

* impl param op hook

* add reduce_chunk

* add zero hook v2

* add zero dp

* fix TensorInfo

* impl load balancing when using zero without chunk

* fix zero hook

* polish chunk

* fix bugs

* ddp ok

* zero ok

* polish code

* fix bugs about load balancing

* polish code

* polish code

* add end-to-end test

* polish code

* polish code

* polish code

* fix typo

* add test_chunk

* fix bugs

* fix bugs

* polish code
ver217
2022-05-31 12:00:12 +08:00
committed by GitHub
parent cfa6c1b46b
commit 9492a561c3
8 changed files with 618 additions and 4 deletions
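The commits above introduce a chunk manager, a parameter-op hook (ZeROHookV2), and a ColoDDPV2 wrapper that together add chunk-based ZeRO data parallelism; the diff below shows only the ColoDDPV2 side. As a rough mental model of the chunk idea (an illustrative toy, not the actual ChunkManager in this commit; the ToyChunk name and its fields are invented for the sketch), a chunk is a flat buffer that packs several parameters so they can be reduced and released as one unit:

import torch
import torch.distributed as dist

# Toy sketch of a "chunk": several tensors packed into one flat buffer so a
# single collective handles all of them. The real ChunkManager additionally
# tracks per-tensor states and performs lazy release.
class ToyChunk:
    def __init__(self, capacity: int, dtype=torch.half):
        self.buffer = torch.zeros(capacity, dtype=dtype)
        self.offsets = {}      # tensor -> (start, numel) inside the buffer
        self.used = 0

    def append(self, t: torch.Tensor) -> None:
        n = t.numel()
        assert self.used + n <= self.buffer.numel(), "chunk is full"
        self.buffer[self.used:self.used + n].copy_(t.detach().reshape(-1))
        self.offsets[t] = (self.used, n)
        self.used += n

    def reduce(self) -> None:
        # one all-reduce covers every tensor packed into this chunk
        if dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(self.buffer)
            self.buffer /= dist.get_world_size()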


@@ -3,8 +3,10 @@ import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
+from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
 
-__all__ = ['ColoDDP']
+__all__ = ['ColoDDP', 'ColoDDPV2']
 
 
 def free_storage(data: torch.Tensor) -> None:
@@ -76,3 +78,54 @@ class ColoDDP(torch.nn.Module):
             else:
                 p._saved_grad.requires_grad_(False)
             p._saved_grad.zero_()
+
+
+class ColoDDPV2(ColoDDP):
+
+    def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
+        super().__init__(module)
+        self.chunk_manager = chunk_manager
+        self.param_op_hook = ZeROHookV2(chunk_manager)
+        self.fp32_params = []
+        # TODO: get param order and filter unused params
+        for p in module.parameters():
+            assert p.dtype == torch.half
+            fp32_p = p.float()
+            self.chunk_manager.append_tensor(p, 'fp16_param')
+            self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
+            self.fp32_params.append(fp32_p)
+
+    def forward(self, *args, **kwargs):
+        self.module.zero_grad(set_to_none=True)
+        for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
+            if not self.chunk_manager.is_chunk_free(p):
+                self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
+        with use_param_op_hooks(self.param_op_hook):
+            outputs = self.module(*args, **kwargs)
+        self.chunk_manager.exec_lazy_release()
+        return outputs
+
+    def backward(self, loss: torch.Tensor):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            loss.backward()
+        self.chunk_manager.exec_lazy_release()
+        for p in self.module.parameters():
+            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+                p.grad = None
+            else:
+                p.grad = p.data
+
+    def grad_handle(self, p, grad):
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        with torch._C.DisableTorchFunction():
+            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
+            self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
+            self.chunk_manager.reduce_chunk(p)
+            self.chunk_manager.release_chunk(p)
+        return empty_grad
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        self.module.zero_grad(set_to_none=True)
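A minimal usage sketch of the new wrapper, assuming the distributed context has already been initialized (e.g. via colossalai.launch, not shown in this diff) and that a ChunkManager can be constructed from a chunk size; that constructor signature is an assumption, only ColoDDPV2(module, chunk_manager) is taken from the diff:

import torch
from colossalai.tensor import ChunkManager
from colossalai.nn.parallel import ColoDDPV2

# Assumes the process group / gpc has been set up beforehand (e.g. colossalai.launch).
chunk_size = 1024 * 1024
chunk_manager = ChunkManager(chunk_size)          # assumed constructor arguments

model = torch.nn.Linear(8, 8).half().cuda()       # ColoDDPV2 asserts fp16 parameters
model = ColoDDPV2(model, chunk_manager)

x = torch.randn(4, 8, dtype=torch.half, device='cuda')
loss = model(x).sum()
model.backward(loss)                              # call backward() on the wrapper, not loss.backward()

Per the diff, forward() runs the module under the ZeRO param-op hook, backward() drives lazy chunk release, and grad_handle() scales, copies, reduces, and releases each gradient's chunk once it is ready.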