Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support lazy init (#4202)
* [shardformer] support lazy init
* [shardformer] linear support lazy init
* [shardformer] embedding support lazy init
* [shardformer] norm support lazy init
* [shardformer] fused linear support lazy init
* [test] update shardformer test layer
* [test] shardformer with lazy init fit ddp
* [lazy] hotfix deepcopy of param
* [shardformer] fix bert policy and update test
* [shardformer] fix bloom policy and update test
* [shardformer] fix opt policy and update test
* [shardformer] fix t5 policy and update test
* [shardformer] fix gpt2 policy and update test
* [shardformer] fix llama policy and update test
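For context, a minimal sketch of the lazy-init workflow this change plugs into, before the per-layer sharding that shardformer applies. The import path and the `nn.Linear` stand-in model are assumptions for illustration, not part of the diff:

    import torch.nn as nn

    from colossalai.lazy import LazyInitContext  # assumed import path

    # Build the model under the context: parameters become LazyTensors and
    # no real device memory is allocated yet.
    with LazyInitContext():
        model = nn.Linear(1024, 1024)

    # A shardformer policy could now decide how each parameter is sharded.

    # materialize() turns every LazyTensor into a real Parameter in-place.
    LazyInitContext.materialize(model)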
@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor
 from torch.nn import Parameter
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
@@ -99,8 +100,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
-    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
+    if cls_to_become is Parameter:
+        # to fit UninitializedParameter
+        delattr(tensor, '_is_param')
     tensor.data = target
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
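The `_is_param` flag leans on PyTorch's `_ParameterMeta.__instancecheck__`, which (in recent PyTorch versions) treats any tensor carrying `_is_param = True` as a `Parameter`. That is how a not-yet-materialized `LazyTensor` can pass `isinstance(..., nn.Parameter)` checks; the flag is then removed in `_convert_cls` once the class really is swapped to `Parameter`. A standalone illustration with a plain tensor standing in for a `LazyTensor` (not colossalai code):

    import torch
    import torch.nn as nn

    t = torch.empty(2, 2)
    assert not isinstance(t, nn.Parameter)

    # _ParameterMeta.__instancecheck__ accepts any tensor flagged with _is_param,
    # which is the hack LazyTensor uses before materialization.
    t._is_param = True
    assert isinstance(t, nn.Parameter)

    # Once the flag is deleted (as _convert_cls does above), ordinary
    # isinstance semantics apply again.
    delattr(t, '_is_param')
    assert not isinstance(t, nn.Parameter)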
@@ -198,10 +202,10 @@ class LazyTensor(torch.Tensor):
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
         """
-        self._factory_method = None
-        self._op_buffer = None
-        self._materialized_data = None
-        self._meta_data = None
+        delattr(self, '_factory_method')
+        delattr(self, '_op_buffer')
+        delattr(self, '_materialized_data')
+        delattr(self, '_meta_data')

     @staticmethod
     def _replace_with_materialized(x):
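Using `delattr` rather than assigning `None` removes the attributes entirely, so the converted tensor carries no leftover lazy-init slots after `_convert_cls` swaps its class. A hypothetical toy (not colossalai code) showing the difference:

    class Holder:
        def __init__(self):
            self._materialized_data = [0] * 1000

    h1, h2 = Holder(), Holder()

    h1._materialized_data = None   # attribute still exists, value is None
    del h2._materialized_data      # attribute is gone entirely

    assert hasattr(h1, '_materialized_data')
    assert not hasattr(h2, '_materialized_data')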
@@ -350,20 +354,19 @@ class LazyTensor(torch.Tensor):
         def factory_fn():
             # if self is materialized, return self
             new_tensor = self.materialize() if type(self) is LazyTensor else self
-            copied = new_tensor.detach().clone()
-            if new_tensor.requires_grad:
-                copied.requires_grad_()
-            return copied
+            return _copy_tensor(new_tensor, new_tensor.requires_grad)

         if self._materialized_data is not None:
             # self is early materialized
-            copied = self._materialized_data.detach().clone()
-            if self.requires_grad:
-                copied.requires_grad_()
+            copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
             target = LazyTensor(factory_fn, meta_data=self._meta_data)

+        if isinstance(self, Parameter):
+            # hack isinstance check of parameter
+            target._is_param = True
+
         memo[id(self)] = target
         return target
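With this fix, `copy.deepcopy` of a module holding lazy parameters keeps the parameter-ness of each tensor, which is what lets a lazily initialized model survive utilities that deepcopy modules (e.g. the DDP-fit test mentioned in the commit message). A rough usage sketch, assuming the `colossalai.lazy` import path and `nn.Linear` as a stand-in model:

    import copy

    import torch.nn as nn
    from colossalai.lazy import LazyInitContext  # assumed import path

    with LazyInitContext():
        model = nn.Linear(4, 4)           # weight/bias are LazyTensors flagged as parameters

    model_copy = copy.deepcopy(model)     # __deepcopy__ above keeps the _is_param flag

    LazyInitContext.materialize(model)
    LazyInitContext.materialize(model_copy)

    assert isinstance(model.weight, nn.Parameter)
    assert isinstance(model_copy.weight, nn.Parameter)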
@@ -408,6 +411,10 @@ class LazyTensor(torch.Tensor):
     def __hash__(self):
         return id(self)

+    def __rpow__(self, other):
+        dtype = torch.result_type(self, other)
+        return torch.tensor(other, dtype=dtype, device=self.device)**self
+

 class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
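`__rpow__` covers expressions of the form `scalar ** lazy_tensor`: the scalar base is promoted with `torch.result_type` and placed on the tensor's device before the forward `**` is applied. The same arithmetic on a plain tensor, for illustration only:

    import torch

    exponent = torch.arange(3, dtype=torch.float32)

    # What the new __rpow__ computes for `2 ** exponent`:
    dtype = torch.result_type(exponent, 2)
    out = torch.tensor(2, dtype=dtype, device=exponent.device) ** exponent

    assert torch.equal(out, torch.tensor([1., 2., 4.]))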
@@ -536,7 +543,7 @@ class LazyInitContext:

     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
             module (nn.Module): Target ``nn.Module``
@@ -553,7 +560,7 @@ class LazyInitContext:
                    device_mesh: DeviceMesh,
                    sharding_spec_dict: Dict[str, ShardingSpec],
                    verbose: bool = False) -> nn.Module:
-        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
             module (nn.Module): Target ``nn.Module``
@@ -625,3 +632,9 @@ def _is_int_tuple(args) -> bool:
         if not isinstance(x, int):
             return False
     return True
+
+
+def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
+    copied = tensor.data.clone()
+    copied.requires_grad = requires_grad
+    return copied
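The helper consolidates the detach-then-clone pattern the old `__deepcopy__` code spelled out inline: cloning `tensor.data` yields a copy with no autograd history, and `requires_grad` is then set on the fresh leaf. A standalone check of that behaviour on a plain tensor:

    import torch

    src = torch.randn(2, 2, requires_grad=True)

    # Equivalent to the new _copy_tensor(src, src.requires_grad):
    copied = src.data.clone()
    copied.requires_grad = src.requires_grad

    assert copied.grad_fn is None            # no autograd history carried over
    assert copied.requires_grad is True      # grad flag restored on the fresh leaf
    assert copied.data_ptr() != src.data_ptr()  # independent storage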