[shardformer] support lazy init (#4202)

* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
Author: Hongxin Liu
Date:   2023-07-10 10:48:53 +08:00
Parent: f3bcc292c8
Commit: 890774b2fb

25 changed files with 263 additions and 157 deletions
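
For context, the flow this commit enables: build the model under ``LazyInitContext`` so parameters are ``LazyTensor``s holding only meta data, hand the model to shardformer (or wrap it with DDP), then materialize real storage afterwards. A minimal sketch; the import path shown is an assumption based on this file, not part of the diff:

    import torch.nn as nn
    from colossalai.lazy import LazyInitContext  # assumed import path

    # build the model without allocating real memory: every parameter
    # becomes a LazyTensor that records its factory function and ops
    with LazyInitContext():
        model = nn.Linear(1024, 1024)

    # allocate and initialize the real parameters in-place
    LazyInitContext.materialize(model)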

colossalai/lazy/lazy_init.py

@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Parameter
 from torch.utils._pytree import tree_map
 
 from colossalai._analyzer._subclasses import MetaTensor
@@ -99,8 +100,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
-    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
+    if cls_to_become is Parameter:
+        # to fit UninitializedParameter
+        delattr(tensor, '_is_param')
     tensor.data = target
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
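
The added ``delattr(tensor, '_is_param')`` matters because recent PyTorch versions implement ``isinstance(x, Parameter)`` via a metaclass hook (``_ParameterMeta.__instancecheck__``) that also accepts any tensor carrying a truthy ``_is_param`` attribute (the same mechanism ``UninitializedParameter`` relies on). Once the class is actually swapped to ``Parameter``, the flag is stale and gets removed. A small illustration of that hook:

    import torch
    from torch.nn import Parameter

    t = torch.empty(2, 2)
    assert not isinstance(t, Parameter)

    # the flag alone makes the metaclass report a plain tensor as a
    # Parameter, without changing the tensor's actual class
    t._is_param = True
    assert isinstance(t, Parameter)
    assert type(t) is torch.Tensor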
@@ -198,10 +202,10 @@ class LazyTensor(torch.Tensor):
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
         """
-        self._factory_method = None
-        self._op_buffer = None
-        self._materialized_data = None
-        self._meta_data = None
+        delattr(self, '_factory_method')
+        delattr(self, '_op_buffer')
+        delattr(self, '_materialized_data')
+        delattr(self, '_meta_data')
 
     @staticmethod
     def _replace_with_materialized(x):
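
Switching ``clean()`` from assigning ``None`` to ``delattr`` removes the attributes from the tensor's ``__dict__`` outright, so a materialized tensor carries no leftover lazy-init state and any stray later access fails loudly instead of silently returning ``None``. Roughly:

    import torch

    t = torch.empty(2)
    t._materialized_data = torch.ones(2)   # stand-in for LazyTensor state

    del t._materialized_data               # what delattr(self, ...) does
    assert not hasattr(t, '_materialized_data')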
@@ -350,20 +354,19 @@ class LazyTensor(torch.Tensor):
 
         def factory_fn():
             # if self is materialized, return self
             new_tensor = self.materialize() if type(self) is LazyTensor else self
-            copied = new_tensor.detach().clone()
-            if new_tensor.requires_grad:
-                copied.requires_grad_()
-            return copied
+            return _copy_tensor(new_tensor, new_tensor.requires_grad)
 
         if self._materialized_data is not None:
             # self is early materialized
-            copied = self._materialized_data.detach().clone()
-            if self.requires_grad:
-                copied.requires_grad_()
+            copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
             target = LazyTensor(factory_fn, meta_data=self._meta_data)
 
+        if isinstance(self, Parameter):
+            # hack isinstance check of parameter
+            target._is_param = True
+
         memo[id(self)] = target
         return target
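
With the ``_is_param`` flag set on the copy, ``copy.deepcopy`` of a lazily initialized module yields parameters that still pass the ``isinstance`` check before materialization, which is what makes the lazy model fit DDP (see the commit message). A hedged usage sketch, import path assumed:

    import copy
    import torch.nn as nn
    from colossalai.lazy import LazyInitContext  # assumed import path

    with LazyInitContext():
        model = nn.Linear(8, 8)   # weights stay lazy, nothing allocated

    replica = copy.deepcopy(model)
    # copied weights are fresh LazyTensors but still look like Parameters
    assert all(isinstance(p, nn.Parameter) for p in replica.parameters())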
@@ -408,6 +411,10 @@ class LazyTensor(torch.Tensor):
     def __hash__(self):
         return id(self)
 
+    def __rpow__(self, other):
+        dtype = torch.result_type(self, other)
+        return torch.tensor(other, dtype=dtype, device=self.device)**self
+
 
 class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
@@ -536,7 +543,7 @@ class LazyInitContext:
 
     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
 
         Args:
             module (nn.Module): Target ``nn.Module``
@@ -553,7 +560,7 @@ class LazyInitContext:
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
"""Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
module (nn.Module): Target ``nn.Module``
@@ -625,3 +632,9 @@ def _is_int_tuple(args) -> bool:
         if not isinstance(x, int):
             return False
     return True
+
+
+def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
+    copied = tensor.data.clone()
+    copied.requires_grad = requires_grad
+    return copied
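
``_copy_tensor`` clones through ``.data``, so the copy is detached from any autograd history and comes back as a leaf tensor whose ``requires_grad`` flag is then set explicitly, collapsing the old ``detach().clone()`` plus ``requires_grad_()`` pattern into one helper. For instance:

    import torch

    src = torch.randn(3, requires_grad=True)

    copied = src.data.clone()      # clone outside autograd: no grad_fn
    copied.requires_grad = True    # restore the flag on the fresh leaf
    assert copied.is_leaf and copied.grad_fn is None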