[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
This commit is contained in:
Frank Lee
2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -8,8 +8,9 @@ from torch import Tensor
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
"""
target = self._materialize_data()
self.clean()
local_tensor = DTensor(target, layout).local_tensor
local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
return _convert_cls(self, local_tensor)
def clean(self) -> None: