Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 19:40:28 +00:00
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading
* polish code
@@ -1,3 +1,6 @@
+import copy
+import operator
+from functools import reduce
from typing import Union

import torch
@@ -6,13 +9,165 @@ from torch.distributed import ProcessGroup

from colossalai.device.device_mesh import DeviceMesh

-from .d_tensor import DTensor
from .layout import Layout
from .layout_converter import LayoutConverter
from .sharding_spec import ShardingSpec

layout_converter = LayoutConverter()
-def shard_rowwise(tensor: torch.Tensor,
-                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
-                  inplace: bool = False) -> DTensor:


def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a distributed tensor.
    """
    return hasattr(tensor, "dist_layout")
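

# Illustrative sketch (not part of the patch): the check above is purely attribute-based, so any
# torch.Tensor carrying a dist_layout attribute counts. The object() below is a hypothetical
# stand-in for the Layout that distribute_tensor() attaches in real use.
def _example_is_distributed_tensor():
    plain = torch.randn(4, 4)
    assert not is_distributed_tensor(plain)

    plain.dist_layout = object()    # stand-in only; real code attaches a Layout
    assert is_distributed_tensor(plain)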


def is_sharded(dtensor: torch.Tensor) -> bool:
    """
    Check if a tensor is sharded.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be checked.

    Returns:
        bool: True if the tensor is sharded, False otherwise.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    # a sharded tensor's local shape differs from the global shape recorded in its layout
    return list(dtensor.shape) != list(dtensor.dist_layout.global_shape)
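

# Illustrative sketch (not part of the patch): is_sharded only compares the local shape against
# the global shape recorded in the layout. SimpleNamespace is a hypothetical stand-in for a real
# Layout, used here purely to show the comparison.
def _example_is_sharded():
    from types import SimpleNamespace

    local_shard = torch.randn(2, 4)    # one piece of a logical (8, 4) tensor
    local_shard.dist_layout = SimpleNamespace(global_shape=torch.Size([8, 4]))
    assert is_sharded(local_shard)

    replica = torch.randn(8, 4)        # holds the full data, so it is not sharded
    replica.dist_layout = SimpleNamespace(global_shape=torch.Size([8, 4]))
    assert not is_sharded(replica)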


def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be hijacked.

    Returns:
        torch.Tensor: The hijacked tensor.
    """
    dtensor._old_detach = dtensor.detach
    dtensor._old_clone = dtensor.clone

    def new_detach(self):
        t_ = self._old_detach()
        t_.dist_layout = copy.deepcopy(self.dist_layout)
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._old_clone(*args, **kwargs)
        t_.dist_layout = copy.deepcopy(self.dist_layout)
        return t_

    # bind the new methods to the tensor
    dtensor.detach = new_detach.__get__(dtensor)
    dtensor.clone = new_clone.__get__(dtensor)
    return dtensor
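

# Illustrative sketch (not part of the patch): after hijacking, detach() and clone() still return
# ordinary tensors, but each result carries a deep copy of dist_layout. SimpleNamespace is again
# a hypothetical stand-in for a real Layout object.
def _example_hijack_detach_and_clone():
    from types import SimpleNamespace

    t = torch.randn(2, 4)
    t.dist_layout = SimpleNamespace(global_shape=torch.Size([8, 4]))
    _hijack_detach_and_clone(t)

    detached, cloned = t.detach(), t.clone()
    assert is_distributed_tensor(detached) and is_distributed_tensor(cloned)
    assert detached.dist_layout is not t.dist_layout    # deep-copied rather than shared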


def _construct_default_sharding_spec(tensor: torch.Tensor) -> ShardingSpec:
    '''
    Construct the default sharding specification for the tensor.

    Args:
        tensor (`torch.Tensor`): the tensor to be sharded.

    Returns:
        A `ShardingSpec` object without any sharding specified.
    '''
    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})


def _apply_layout(tensor, layout):
    '''
    Apply the layout to the local tensor during the initialization process.
    '''
    # the layout converter requires a source and a target layout:
    # we construct the source layout for the unsharded tensor
    # and use the given layout as the target layout for the sharded tensor
    source_spec = _construct_default_sharding_spec(tensor)
    source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
    sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
    return sharded_tensor


def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert the given tensor to a distributed tensor.

    Args:
        tensor (torch.Tensor): The tensor to be converted.
        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.

    Returns:
        torch.Tensor: The distributed tensor.
    """
    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)

    # shard tensor
    sharded_tensor = _apply_layout(tensor, dist_layout)

    # hack some tensor methods
    _hijack_detach_and_clone(sharded_tensor)

    return sharded_tensor
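

# Illustrative usage sketch (not part of the patch). It assumes a DeviceMesh built elsewhere
# (construction omitted) and a multi-process launch, e.g. via torchrun; the spec {0: [0]}
# shards dim 0 across the first mesh axis, mirroring what shard_rowwise does below.
def _example_distribute_tensor(device_mesh: DeviceMesh) -> torch.Tensor:
    full_weight = torch.randn(8, 4)
    spec = ShardingSpec(dim_size=full_weight.dim(), dim_partition_dict={0: [0]})

    dtensor = distribute_tensor(full_weight, device_mesh, spec)
    assert is_distributed_tensor(dtensor)
    assert tuple(get_global_shape(dtensor)) == (8, 4)    # the layout still records the full shape
    return dtensor                                       # locally this rank holds only its row block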


def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    '''
    Convert the layout of the given distributed tensor to the target sharding specification
    and return the re-sharded tensor.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): the target sharding specification.

    Returns:
        torch.Tensor: the re-sharded tensor.
    '''
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    global_shape = get_global_shape(dtensor)
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    resharded_tensor = layout_converter.apply(tensor=dtensor,
                                              source_layout=dtensor.dist_layout,
                                              target_layout=target_layout)
    return resharded_tensor
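

# Illustrative sketch (not part of the patch): re-shard a row-sharded tensor into a column-sharded
# one on the same device mesh. The DeviceMesh is assumed to exist already; all ranks in the mesh
# must call this together, since the layout conversion involves collective communication.
def _example_redistribute(device_mesh: DeviceMesh) -> torch.Tensor:
    row_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})
    col_spec = ShardingSpec(dim_size=2, dim_partition_dict={-1: [0]})

    dtensor = distribute_tensor(torch.randn(8, 4), device_mesh, row_spec)
    resharded = redistribute(dtensor, device_mesh, col_spec)    # returns a new local tensor
    return resharded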


def to_global(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Convert a distributed tensor to the global tensor with the given layout.
    This function returns a native `torch.Tensor` object.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.

    Returns:
        torch.Tensor: the global tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    layout_converter = LayoutConverter()

    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
    device_mesh = get_device_mesh(dtensor)
    global_shape = get_global_shape(dtensor)
    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)

    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
    return global_tensor
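

# Illustrative sketch (not part of the patch): gather a sharded tensor back into a plain, fully
# materialized torch.Tensor. Every rank in the tensor's device mesh has to call this collectively;
# afterwards each rank holds the complete data.
def _example_to_global(dtensor: torch.Tensor) -> torch.Tensor:
    full = to_global(dtensor)
    assert tuple(full.shape) == tuple(get_global_shape(dtensor))
    return full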


def shard_rowwise(
    tensor: torch.Tensor,
    group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
) -> torch.Tensor:
    """
    Shard the first dim of the given tensor.

@@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor,

    Returns:
-        DTensor: The sharded tensor.
+        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
@@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor,
    else:
        assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
        device_mesh = group_or_device_mesh

    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})

-    if not inplace:
-        tensor = tensor.detach().clone()
-
-    return DTensor(tensor, device_mesh, sharding_spec)
+    return distribute_tensor(tensor, device_mesh, sharding_spec)
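

# Illustrative sketch (not part of the patch): with group_or_device_mesh left as None the global
# process group is used, so dim 0 is split across all ranks. Assumes a torchrun launch with
# torch.distributed already initialized and a world size that divides the first dimension.
def _example_shard_rowwise() -> torch.Tensor:
    import torch.distributed as dist

    world_size = dist.get_world_size()
    # in practice the full tensor should be identical on every rank, e.g. created under a fixed seed
    weight = torch.randn(4 * world_size, 16)
    sharded = shard_rowwise(weight)    # each rank keeps a (4, 16) row block
    assert is_distributed_tensor(sharded)
    return sharded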


-def shard_colwise(tensor: torch.Tensor,
-                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
-                  inplace: bool = False) -> DTensor:
+def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
    """
    Shard the last dim of the given tensor.

@@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor,

    Returns:
-        DTensor: The sharded tensor.
+        torch.Tensor: The sharded tensor.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
@@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor,
        device_mesh = group_or_device_mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

-    if not inplace:
-        tensor = tensor.detach().clone()
-
-    return DTensor(tensor, device_mesh, sharding_spec)
+    return distribute_tensor(tensor, device_mesh, sharding_spec)
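

# Illustrative sketch (not part of the patch): the column-wise variant splits the last dim instead.
# Same launch assumptions as the row-wise example above.
def _example_shard_colwise() -> torch.Tensor:
    import torch.distributed as dist

    world_size = dist.get_world_size()
    weight = torch.randn(16, 4 * world_size)
    sharded = shard_colwise(weight)    # each rank keeps a (16, 4) column block
    assert tuple(get_global_shape(sharded)) == (16, 4 * world_size)
    return sharded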


def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a distributed tensor into a torch.nn.Parameter, preserving its dist_layout
    and the hijacked detach/clone behaviour.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)

    # make it distributed as well
    param.dist_layout = dtensor.dist_layout
    _hijack_detach_and_clone(param)

    return param
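

# Illustrative sketch (not part of the patch): a sharded weight can replace an ordinary nn.Linear
# parameter so the module can still be saved and loaded; state_dict() then holds only the local
# shard while the layout metadata stays attached. Assumes a torchrun launch with a world size
# that divides the output dimension.
def _example_sharded_parameter() -> torch.nn.Module:
    linear = torch.nn.Linear(16, 8, bias=False)
    sharded_weight = shard_rowwise(linear.weight.data)    # split the output rows across ranks
    linear.weight = sharded_tensor_to_param(sharded_weight)

    assert is_distributed_tensor(linear.weight)
    assert get_layout(linear.weight) is sharded_weight.dist_layout
    return linear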


def compute_global_numel(dtensor: torch.Tensor) -> int:
    """
    Compute the global number of elements in the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        int: The global number of elements in the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
    return numel
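

# Illustrative sketch (not part of the patch): the global element count is derived from the
# layout's global shape, so it is at least as large as the local shard's numel.
def _example_compute_global_numel(dtensor: torch.Tensor) -> int:
    global_numel = compute_global_numel(dtensor)
    assert global_numel >= dtensor.numel()    # equal only when the tensor is not actually sharded
    return global_numel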


def get_layout(dtensor: torch.Tensor) -> Layout:
    """
    Get the layout of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        Layout: The layout of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    return dtensor.dist_layout


def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
    """
    Get the global shape of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        torch.Size: The global shape of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    return dtensor.dist_layout.global_shape


def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
    """
    Get the device mesh of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        DeviceMesh: The device mesh of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    return dtensor.dist_layout.device_mesh


def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    Get the sharding spec of the distributed tensor.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        ShardingSpec: The sharding spec of the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
    return dtensor.dist_layout.sharding_spec
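

# Illustrative sketch (not part of the patch): the four accessors above are the read-only API for
# inspecting a distributed tensor; together they expose everything stored in its dist_layout.
def _example_inspect(dtensor: torch.Tensor) -> None:
    layout = get_layout(dtensor)
    assert layout.device_mesh is get_device_mesh(dtensor)
    assert layout.sharding_spec is get_sharding_spec(dtensor)
    assert layout.global_shape == get_global_shape(dtensor)
    print(f"global shape {tuple(get_global_shape(dtensor))}, local shape {tuple(dtensor.shape)}")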