[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading

* polish code
@@ -8,8 +8,9 @@ from torch import Tensor
 from torch.utils._pytree import tree_map
 
 from colossalai._analyzer._subclasses import MetaTensor
-from colossalai.tensor.d_tensor.d_tensor import DTensor
-from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor import distribute_tensor
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 
 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
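
The removed imports (`DTensor`, `Layout`) and the added ones (`DeviceMesh`, `distribute_tensor`, `ShardingSpec`) mark the API swap this commit makes: distribution is now described by a device mesh plus a sharding spec instead of a `Layout`-wrapped `DTensor`. A minimal sketch of constructing the two new inputs; the constructor arguments shown are assumptions for illustration, not taken from this diff:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# Assumed for illustration: a 2x2 logical mesh over 4 physical devices.
device_mesh = DeviceMesh(torch.arange(4), (2, 2), init_process_group=True)

# Assumed for illustration: for a 2-D tensor, shard dim 0 along mesh axis 0
# and replicate along the other mesh axis.
sharding_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})
```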
@@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
         """
         target = self._materialize_data()
         self.clean()
-        local_tensor = DTensor(target, layout).local_tensor
+        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
         return _convert_cls(self, local_tensor)
 
     def clean(self) -> None:
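
The second hunk is the behavioral change in `LazyTensor.distribute`: rather than wrapping the materialized data in a `DTensor` and reading back its `local_tensor`, it now hands the data to `distribute_tensor` together with a device mesh and a sharding spec. A hedged sketch of the new call path, reusing the assumed setup from the sketch above, with `target` standing in for the materialized payload:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# Assumed setup, as in the previous sketch.
device_mesh = DeviceMesh(torch.arange(4), (2, 2), init_process_group=True)
sharding_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})

# Stand-in for the data a LazyTensor would materialize.
target = torch.randn(8, 8)

# Old path (removed): local_tensor = DTensor(target, layout).local_tensor
# New path (added): each rank keeps only its shard of `target`.
local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
```

Note that this path presumes an initialized process group (e.g. via `colossalai.launch` or `torch.distributed.init_process_group`), since the tensor is sharded across ranks.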