[zero] add zero optimizer for ColoTensor (#1046)

* add zero optimizer

* torch ok

* unit test ok

* polish code

* fix bugs

* polish unit test

* polish zero optim

* polish colo ddp v2

* refactor folder structure

* add comment

* polish unit test

* polish zero optim

* polish unit test
ver217, committed by GitHub on 2022-06-02 12:13:15 +08:00
parent e32470b6de
commit 51b9a49655
7 changed files with 245 additions and 15 deletions
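
For readers skimming the diff below: ColoDDPV2 keeps the working parameters in fp16 while holding detached fp32 master copies, and the new ZeRO optimizer (added in one of the files not shown here) steps on those masters. What follows is a minimal, self-contained sketch of that fp16/fp32 master-copy pattern in plain PyTorch; every name in it is made up for illustration, and none of it is the API added by this commit.

import torch


class Fp16MasterParamOptimizer:
    """Illustrative only: fp32 master copies behind fp16 working params,
    with an overflow check before each step. This shows the general pattern
    the diff builds on, not ColossalAI's actual implementation."""

    def __init__(self, fp16_params, lr: float = 1e-3):
        self.fp16_params = list(fp16_params)
        # fp32 master copies, detached from the fp16 autograd graph
        self.fp32_params = [p.float().detach().requires_grad_(True) for p in self.fp16_params]
        self.inner = torch.optim.SGD(self.fp32_params, lr=lr)

    def step(self) -> bool:
        # Gather fp16 grads into the fp32 masters, bailing out on inf/NaN.
        for p16, p32 in zip(self.fp16_params, self.fp32_params):
            if p16.grad is None:
                continue
            g32 = p16.grad.float()
            if not torch.isfinite(g32).all():
                self.zero_grad()
                return False          # overflow: skip this update
            p32.grad = g32
        self.inner.step()
        # Write the updated masters back into the fp16 working params.
        with torch.no_grad():
            for p16, p32 in zip(self.fp16_params, self.fp32_params):
                p16.copy_(p32)
        self.zero_grad()
        return True

    def zero_grad(self) -> None:
        for p in self.fp16_params + self.fp32_params:
            p.grad = None

Typical usage would be: run the fp16 forward and backward, then call step(); when it returns False, a loss scaler (not shown) would usually reduce its scale and the batch is skipped.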


@@ -8,5 +8,6 @@ from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
-__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD',
-           'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT']
+__all__ = [
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
+]


@@ -4,7 +4,8 @@ from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
-from colossalai.tensor import ChunkManager, use_param_op_hooks, TensorState
+from colossalai.tensor.chunk import ChunkManager, TensorState
+from colossalai.tensor.param_op_hook import use_param_op_hooks
 
 __all__ = ['ColoDDP', 'ColoDDPV2']
@@ -87,27 +88,23 @@ class ColoDDPV2(ColoDDP):
         self.chunk_manager = chunk_manager
         self.param_op_hook = ZeROHookV2(chunk_manager)
         self.fp32_params = []
+        self.overflow_counter = 0
         # TODO: get param order and filter unused params
         for p in module.parameters():
             assert p.dtype == torch.half
-            fp32_p = p.float()
+            fp32_p = p.float().detach()
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
 
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         for p, fp32_p in zip(self.module.parameters(), self.fp32_params):
             if not self.chunk_manager.is_chunk_free(p):
                 self.chunk_manager.copy_tensor_to_chunk_slice(p, fp32_p)
         with use_param_op_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
         return outputs
 
-    def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
-            loss.backward()
+    def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
             if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
@@ -115,6 +112,16 @@ class ColoDDPV2(ColoDDP):
             else:
                 p.grad = p.data
 
+    def backward(self, loss: torch.Tensor):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            loss.backward()
+        self._post_backward()
+
+    def backward_by_grad(self, tensor, grad):
+        with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
+            torch.autograd.backward(tensor, grad)
+        self._post_backward()
+
     def grad_handle(self, p, grad):
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
@@ -123,8 +130,11 @@ class ColoDDPV2(ColoDDP):
         if self.dp_world_size > 1:
             grad = grad / self.dp_world_size
         self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
-        self.chunk_manager.reduce_chunk(p)
+        chunk = self.chunk_manager.get_chunk(p)
+        reduced = self.chunk_manager.reduce_chunk(p)
         self.chunk_manager.release_chunk(p)
+        if reduced and not chunk.is_free:
+            self.overflow_counter += chunk.has_inf_or_nan
         return empty_grad
 
     def zero_grad(self, set_to_none: bool = False) -> None:
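
The overflow_counter incremented in grad_handle above is presumably read by the new ZeRO optimizer before it applies an update. Below is a hedged sketch of that consumption pattern, written against plain torch.distributed rather than ColossalAI's actual interfaces; names such as found_overflow and backoff are hypothetical.

import torch
import torch.distributed as dist


def found_overflow(local_overflow_counter: int) -> bool:
    # Sum the per-rank counters so that every rank agrees on whether
    # any gradient chunk contained inf/NaN.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    flag = torch.tensor([local_overflow_counter], dtype=torch.int64, device=device)
    if dist.is_initialized():
        dist.all_reduce(flag)
    return flag.item() > 0


def maybe_step(optimizer, grad_scaler, local_overflow_counter: int) -> None:
    # Skip the update and back off the loss scale on overflow, as in standard
    # mixed-precision training; otherwise step normally.
    if found_overflow(local_overflow_counter):
        optimizer.zero_grad()
        grad_scaler.backoff()      # hypothetical scaler method, named for illustration
        return
    optimizer.step()

The counter is summed rather than checked locally so that an overflow on any data-parallel rank causes every rank to skip the same step.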