mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[dtensor] updated api and doc (#3845)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from types import MethodType
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -8,8 +8,9 @@ from torch import Tensor
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai._analyzer._subclasses import MetaTensor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_NORMAL_FACTORY = [
|
||||
@@ -172,7 +173,7 @@ class LazyTensor(torch.Tensor):
|
||||
self.clean()
|
||||
return _convert_cls(self, target)
|
||||
|
||||
def distribute(self, layout: Layout) -> torch.Tensor:
|
||||
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the given device mesh and sharding spec.
|
||||
|
||||
Args:
|
||||
@@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
|
||||
"""
|
||||
target = self._materialize_data()
|
||||
self.clean()
|
||||
local_tensor = DTensor(target, layout).local_tensor
|
||||
local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor
|
||||
return _convert_cls(self, local_tensor)
|
||||
|
||||
def clean(self) -> None:
|
||||
@@ -536,7 +537,10 @@ class LazyInitContext:
|
||||
return _apply_to_lazy_module(module, apply_fn, verbose)
|
||||
|
||||
@staticmethod
|
||||
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
|
||||
def distribute(module: nn.Module,
|
||||
device_mesh: DeviceMesh,
|
||||
sharding_spec_dict: Dict[str, ShardingSpec],
|
||||
verbose: bool = False) -> nn.Module:
|
||||
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
|
||||
|
||||
Args:
|
||||
@@ -546,7 +550,7 @@ class LazyInitContext:
|
||||
"""
|
||||
|
||||
def apply_fn(name: str, p: LazyTensor):
|
||||
p.distribute(layout_dict[name])
|
||||
p.distribute(device_mesh, sharding_spec_dict[name])
|
||||
|
||||
return _apply_to_lazy_module(module, apply_fn, verbose)
|
||||
|
||||
|
Reference in New Issue
Block a user