Revert "[sync] sync feature/shardformer with develop"

This commit is contained in:
Frank Lee
2023-06-09 09:41:27 +08:00
committed by GitHub
parent 24651fdd4f
commit ddcf58cacf
48 changed files with 445 additions and 3876 deletions

View File

@@ -1,5 +1,5 @@
from types import MethodType
from typing import Callable, Dict, Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
@@ -8,9 +8,8 @@ from torch import Tensor
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.tensor.d_tensor.layout import Layout
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@@ -173,7 +172,7 @@ class LazyTensor(torch.Tensor):
self.clean()
return _convert_cls(self, target)
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
def distribute(self, layout: Layout) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args:
@@ -184,7 +183,7 @@ class LazyTensor(torch.Tensor):
"""
target = self._materialize_data()
self.clean()
local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor
local_tensor = DTensor(target, layout).local_tensor
return _convert_cls(self, local_tensor)
def clean(self) -> None:
@@ -537,10 +536,7 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
def distribute(module: nn.Module,
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
@@ -550,7 +546,7 @@ class LazyInitContext:
"""
def apply_fn(name: str, p: LazyTensor):
p.distribute(device_mesh, sharding_spec_dict[name])
p.distribute(layout_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose)