[dtensor] fixed readme file name and removed deprecated file (#4162)
parent cc3cbe9f6f
commit 190a6ea9c2
@@ -1,142 +0,0 @@
from typing import Optional

import torch
from torch.utils._pytree import tree_map

from .layout import Layout
from .layout_converter import LayoutConverter, to_global
from .sharding_spec import ShardingSpec

layout_converter = LayoutConverter()


class DTensor(torch.Tensor):

    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
        self.local_tensor = local_tensor
        self.data_type = local_tensor.dtype
        self.entire_shape = local_tensor.shape
        self.dist_layout = dist_layout
        self._apply_layout()

    @staticmethod
    def __new__(cls, local_tensor, layout):
        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)

    def __repr__(self):
        return f"DTensor({self.to_global()}, {self.dist_layout})"

    def __str__(self):
        return self.__repr__()

    def layout_convert(self, target_layout):
        '''
        Convert the layout of the tensor from the source layout to the target layout.
        '''
        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
        self.dist_layout = target_layout

    def _apply_layout(self):
        '''
        Apply the layout to the local tensor during the initialization process.
        '''
        source_spec = construct_default_sharding_spec(self.local_tensor)
        source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
                               device_type=self.dist_layout.device_type,
                               sharding_spec=source_spec,
                               entire_shape=self.entire_shape)
        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def filter_arg(arg):
            if isinstance(arg, DTensor):
                return arg.local_tensor
            else:
                return arg

        args = tree_map(filter_arg, args)
        kwargs = tree_map(filter_arg, kwargs)
        # If we want to convert the result into a DTensor, we need to infer the layout
        # of the result from the layouts of the input tensors and the op type.

        return func(*args, **kwargs)

    @property
    def device_mesh(self):
        '''
        Return the device mesh of the tensor.
        '''
        return self.dist_layout.device_mesh

    @property
    def sharding_spec(self):
        '''
        Return the sharding specification of the tensor.
        '''
        return self.dist_layout.sharding_spec

    def to(self, *args, **kwargs):
        '''
        Move the tensor to a new device or convert the tensor to a new dtype.
        '''
        self.local_tensor = self.local_tensor.to(*args, **kwargs)
        self.data_type = self.local_tensor.dtype
        self.dist_layout.device_type = self.local_tensor.device
        # TODO: update the device mesh process groups, or should we just cache
        # both the cpu process groups and the cuda process groups?
        return self

    def to_local(self):
        '''
        Return the local tensor on this rank.
        '''
        return self.local_tensor

    def to_global(self):
        '''
        Recover the global tensor from the distributed tensor.

        Note: This function will all_gather the local tensor into the global tensor and it
        will not change the layout of the DTensor. It is mainly used for debugging or
        checking the correctness of the distributed tensor.
        '''
        return to_global(self.local_tensor, self.dist_layout)


def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
    '''
    Distribute the local tensor into a distributed tensor according to the specified dist_layout.

    Args:
        local_tensor: tensor to be distributed.
        dist_layout: the layout specification of the distributed tensor.

    Returns:
        A 'DTensor' object.
    '''
    return DTensor(local_tensor, dist_layout)


def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
    '''
    This function converts all the parameters in the module to DTensor (DParam).

    Note: This function is subject to future change as DParam has not been implemented yet.
    '''
    for name, param in module.named_parameters():
        if param is not None and not isinstance(param, DTensor):
            # TODO: we could convert the parameter to DParam here,
            # the type of the parameter could be an optional argument.
            setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
    return module


def construct_default_sharding_spec(tensor: torch.Tensor) -> ShardingSpec:
    '''
    Construct the default sharding specification for the tensor.
    '''
    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
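For context, the snippet below is a minimal usage sketch of the removed API; it is not part of the commit. It relies only on the signatures visible in the deleted file (ShardingSpec, Layout, DTensor, distribute_tensor). The `device_mesh` object is an assumption here: it is taken to be a ColossalAI DeviceMesh constructed elsewhere, since the mesh setup is not shown in this file.

    import torch

    # Hypothetical usage sketch of the removed module; not from the commit itself.
    # `device_mesh` is assumed to be a ColossalAI DeviceMesh created elsewhere.
    local_tensor = torch.randn(4, 8)

    # Shard dim 0 of the tensor along mesh dim 0; dim_partition_dict maps tensor dims
    # to device-mesh dims, mirroring the ShardingSpec call in construct_default_sharding_spec.
    sharding_spec = ShardingSpec(dim_size=local_tensor.dim(), dim_partition_dict={0: [0]})

    # Layout is built with the same keyword arguments used inside DTensor._apply_layout.
    dist_layout = Layout(device_mesh=device_mesh,
                         device_type=torch.device('cuda'),
                         sharding_spec=sharding_spec,
                         entire_shape=local_tensor.shape)

    d_tensor = distribute_tensor(local_tensor, dist_layout)
    print(d_tensor.to_local().shape)   # the shard held by this rank
    print(d_tensor.to_global().shape)  # all_gathered back to the full global shape

Because __torch_function__ unwraps DTensor arguments to their local tensors, ordinary torch ops in such a script run on the local shards; inferring the layout of the results is still marked as a TODO in the deleted file.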