mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-30 04:35:17 +00:00
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [optim] add distributed came (#5526) * test CAME under LowLevelZeroOptimizer wrapper * test CAME TP row and col pass * test CAME zero pass * came zero add master and worker param id convert * came zero test pass * came zero test pass * test distributed came passed * reform code, Modify some expressions and add comments * minor fix of test came * minor fix of dist_came and test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix of dist_came and test * rebase dist-optim * rebase dist-optim * fix remaining comments * add test dist came using booster api --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [optim] Distributed Adafactor (#5484) * [feature] solve conflict; update optimizer readme; * [feature] update optimize readme; * [fix] fix testcase; * [feature] Add transformer-bert to testcase;solve a bug related to indivisible shape (induction in use_zero and tp is row parallel); * [feature] Add transformers_bert model zoo in testcase; * [feature] add user documentation to docs/source/feature. 
* [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam; * [feature] modify user documentation; * [fix] fix readme format issue; * [fix] add zero=0 in testcase; cached augment in dict; * [fix] fix percision issue; * [feature] add distributed rms; * [feature] remove useless comment in testcase; * [fix] Remove useless test; open zero test; remove fp16 test in bert exam; * [feature] Extract distributed rms function; * [feature] add booster + lowlevelzeroPlugin in test; * [feature] add Start_with_booster_API case in md; add Supporting Information in md; * [fix] Also remove state movement in base adafactor; * [feature] extract factor function; * [feature] add LowLevelZeroPlugin test; * [fix] add tp=False and zero=True in logic; * [fix] fix use zero logic; * [feature] add row residue logic in column parallel factor; * [feature] add check optim state func; * [feature] Remove duplicate logic; * [feature] update optim state check func and percision test bug; * [fix] update/fix optim state; Still exist percision issue; * [fix] Add use_zero check in _rms; Add plugin support info in Readme; Add Dist Adafactor init Info; * [feature] removed print & comments in utils; * [feature] uodate Readme; * [feature] add LowLevelZeroPlugin test with Bert model zoo; * [fix] fix logic in _rms; * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [fix] remove comments in testcase; * [feature] add zh-Han Readme; --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; (#5676) * [feature] daily update; * [fix] fix dist came; * [feature] refractor dist came; fix percision error; add low level zero test with bert model zoo; * [fix] open rms; fix low level zero test; fix dist came test function name; * [fix] remove redundant test; * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570) * init: add dist lamb; add debiasing for lamb * dist lamb tester mostly done * all tests passed * add comments * all tests passed. Removed debugging statements * moved setup_distributed inside plugin. Added dist layout caching * organize better * update comments * add initial distributed galore * add initial distributed galore * add galore set param utils; change setup_distributed interface * projected grad precision passed * basic precision tests passed * tests passed; located svd precision issue in fwd-bwd; banned these tests * Plugin DP + TP tests passed * move get_shard_dim to d_tensor * add comments * remove useless files * remove useless files * fix zero typo * improve interface * remove moe changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import * fix deepcopy * update came & adafactor to main * fix param map * fix typo --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692) Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: chongqichuizi875 <107315010+chongqichuizi875@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: duanjunwen <54985467+duanjunwen@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
541 lines
18 KiB
Python
541 lines
18 KiB
Python
import copy
|
|
import operator
|
|
from functools import reduce
|
|
from typing import Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
|
from colossalai.tensor.d_tensor.sharding_spec import DimSpec
|
|
|
|
from .layout import Layout
|
|
from .layout_converter import LayoutConverter
|
|
from .sharding_spec import ShardingSpec
|
|
|
|
# Module-level converter shared by _apply_layout() and redistribute();
# clear_layout_converter() empties its `cached_solution` cache.
layout_converter = LayoutConverter()
|
|
|
|
# DimSpec presumably describing a tensor dim partitioned over device-mesh axis 0
# (the same `[0]` mesh-axis list shard_rowwise/shard_colwise put in `dim_partition_dict`);
# get_shard_dim_1d() searches a tensor's sharding sequence for it.
_SHARD_DIM = DimSpec([0])
|
|
|
|
|
|
def get_shard_dim_1d(p: torch.Tensor):
    """
    Return the index of the tensor dimension that is sharded along mesh axis 0.

    This is the usual situation for 1D tensor parallelism, where a single tensor
    dimension is split across the devices.

    Args:
        p (torch.Tensor): a distributed tensor (one that carries ``dist_layout``).

    Returns:
        int: position of the first entry of the tensor's sharding sequence that
        equals ``_SHARD_DIM``.

    Raises:
        ValueError: if ``p`` is not a distributed tensor, or (raised by ``.index``)
            if no entry of the sharding sequence equals ``_SHARD_DIM``.
    """
    if is_distributed_tensor(p):
        sharding_sequence = p.dist_layout.sharding_spec.sharding_sequence
        return sharding_sequence.index(_SHARD_DIM)
    raise ValueError("p is not a distributed tensor")
|
|
|
|
|
|
def clear_layout_converter():
    """
    Drop every conversion solution cached by the module-level ``layout_converter``.

    The global name is only read, never rebound, so no ``global`` declaration is needed.
    """
    cache = layout_converter.cached_solution
    cache.clear()
|
|
|
|
|
|
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
    """
    Tell whether ``tensor`` has been turned into a distributed tensor.

    A tensor counts as distributed exactly when it carries a ``dist_layout``
    attribute, which the helpers in this module attach.

    Args:
        tensor (torch.Tensor): tensor to inspect.

    Returns:
        bool: ``True`` if ``tensor`` has a ``dist_layout`` attribute, ``False`` otherwise.
    """
    has_layout = hasattr(tensor, "dist_layout")
    return has_layout
|
|
|
|
|
|
def is_sharded(dtensor: torch.Tensor) -> bool:
    """
    Check if a distributed tensor is actually sharded.

    The local tensor's shape is compared with the global shape stored in its
    ``dist_layout``: if they differ, at least one dimension has been split
    across devices.

    Args:
        dtensor (torch.Tensor): The distributed tensor to be checked.

    Returns:
        bool: True if the local shape differs from the global shape (the tensor
        is sharded), False if the local tensor already has the full global shape.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # A sharded tensor only holds a slice locally, so its shape differs from the
    # global one. The previous `==` returned True for UNsharded tensors.
    return list(dtensor.shape) != list(dtensor.dist_layout.global_shape)
|
|
|
|
|
|
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
|
|
|
|
Args:
|
|
tensor (torch.Tensor): The tensor to be hijacked.
|
|
|
|
Returns:
|
|
torch.Tensor: The hijacked tensor.
|
|
"""
|
|
dtensor._old_detach = dtensor.detach
|
|
dtensor._old_clone = dtensor.clone
|
|
|
|
def new_detach(self):
|
|
t_ = self._old_detach()
|
|
t_.dist_layout = copy.deepcopy(self.dist_layout)
|
|
return t_
|
|
|
|
def new_clone(self, *args, **kwargs):
|
|
t_ = self._old_clone(*args, **kwargs)
|
|
t_.dist_layout = copy.deepcopy(self.dist_layout)
|
|
return t_
|
|
|
|
# bind the new methods to the tensor
|
|
dtensor.detach = new_detach.__get__(dtensor)
|
|
dtensor.clone = new_clone.__get__(dtensor)
|
|
return dtensor
|
|
|
|
|
|
def _construct_default_sharding_spec(
    tensor: torch.Tensor,
) -> ShardingSpec:
    """
    Build a sharding spec that leaves every dimension of ``tensor`` unpartitioned.

    Args:
        tensor (`torch.Tensor`): tensor whose number of dimensions sets the spec's ``dim_size``.

    Returns:
        A `ShardingSpec` with an empty ``dim_partition_dict``, i.e. no sharding.
    """
    ndim = tensor.dim()
    no_partition = {}
    return ShardingSpec(dim_size=ndim, dim_partition_dict=no_partition)
|
|
|
|
|
|
def _apply_layout(tensor, layout):
    """
    Convert a full local tensor into ``layout`` during initialization.

    The layout converter needs a source and a target layout. The source layout is
    built here: ``tensor`` is described as unsharded on ``layout``'s device mesh,
    and its own shape is used as the global shape. ``layout`` is the target.

    Args:
        tensor (torch.Tensor): the complete, unsharded tensor.
        layout (Layout): the target layout.

    Returns:
        torch.Tensor: the tensor returned by ``layout_converter.apply``.
    """
    unsharded_spec = _construct_default_sharding_spec(tensor)
    unsharded_layout = Layout(
        device_mesh=layout.device_mesh,
        sharding_spec=unsharded_spec,
        global_shape=tensor.shape,
    )
    return layout_converter.apply(tensor=tensor, source_layout=unsharded_layout, target_layout=layout)
|
|
|
|
|
|
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Shard a regular tensor according to ``sharding_spec`` on ``device_mesh``.

    The input's own shape is taken as the global shape. The returned tensor also
    has its ``detach``/``clone`` methods patched so that their results carry a
    deep copy of ``dist_layout``.

    Args:
        tensor (torch.Tensor): the full tensor to shard; must not already be distributed.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): how each tensor dimension is split over the mesh.

    Returns:
        torch.Tensor: the tensor produced by the layout converter for the target layout.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
    local_shard = _apply_layout(tensor, target_layout)

    # keep dist_layout on the results of detach()/clone()
    _hijack_detach_and_clone(local_shard)
    return local_shard
|
|
|
|
|
|
def init_as_dtensor(
    tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size
) -> torch.Tensor:
    """
    Mark an existing local tensor as a distributed tensor without moving any data.

    Unlike ``distribute_tensor``, nothing is sharded here. ``tensor`` is used as-is
    and only the layout metadata is attached to it. Presumably ``tensor`` already
    holds this rank's shard; the caller is responsible for that.

    Args:
        tensor (torch.Tensor): local tensor to annotate in place; must not already be distributed.
        device_mesh (DeviceMesh): device mesh of the layout.
        sharding_spec (ShardingSpec): sharding specification of the layout.
        global_shape (torch.Size): shape of the full, unsharded tensor.

    Returns:
        torch.Tensor: the same ``tensor`` object, now carrying ``dist_layout`` and
        patched ``detach``/``clone``.
    """
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    tensor.dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    _hijack_detach_and_clone(tensor)
    return tensor
|
|
|
|
|
|
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
    """
    Convert a distributed tensor from its current layout to a new one.

    The target layout is built from ``device_mesh`` and ``sharding_spec`` and keeps
    the tensor's current global shape. The conversion is delegated to the
    module-level ``layout_converter``, and its result is returned. This function
    does not write the result back into ``dtensor``, so callers should use the
    returned tensor.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.
        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
        sharding_spec (ShardingSpec): the target sharding specification.

    Returns:
        torch.Tensor: the tensor in the target layout, as returned by ``layout_converter.apply``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    global_shape = get_global_shape(dtensor)
    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
    resharded_tensor = layout_converter.apply(
        tensor=dtensor, source_layout=dtensor.dist_layout, target_layout=target_layout
    )
    return resharded_tensor
|
|
|
|
|
|
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Gather a distributed tensor into its full, unsharded global tensor.

    The tensor is converted from its current ``dist_layout`` to a layout on the
    same device mesh, with the same global shape and no partitioned dimension.

    Args:
        dtensor (torch.Tensor): the distributed tensor to be converted.

    Returns:
        torch.Tensor: the global tensor produced by the layout converter.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."

    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
    device_mesh = get_device_mesh(dtensor)
    global_shape = get_global_shape(dtensor)
    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)

    # Reuse the module-level converter (as _apply_layout/redistribute do) instead of
    # shadowing it with a fresh local instance, so cached conversion solutions are
    # shared and can be cleared by clear_layout_converter().
    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
    return global_tensor
|
|
|
|
|
|
def shard_rowwise(
    tensor: torch.Tensor,
    group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
) -> torch.Tensor:
    """
    Shard the first dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            A ProcessGroup is converted with ``DeviceMesh.from_process_group``; a DeviceMesh must be 1D
            (checked with ``assert``).
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor, produced by ``distribute_tensor``.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for row-wise sharding."
        device_mesh = group_or_device_mesh

    # tensor dim 0 is partitioned over axis 0 of the 1D mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})

    return distribute_tensor(tensor, device_mesh, sharding_spec)
|
|
|
|
|
|
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
    """
    Shard the last dim of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            A ProcessGroup is converted with ``DeviceMesh.from_process_group``; a DeviceMesh must be 1D
            (checked with ``assert``).
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor, produced by ``distribute_tensor``.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for column-wise sharding."
        device_mesh = group_or_device_mesh

    # the last tensor dim (-1) is partitioned over axis 0 of the 1D mesh
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

    return distribute_tensor(tensor, device_mesh, sharding_spec)
|
|
|
|
|
|
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a distributed tensor in a ``torch.nn.Parameter`` that is itself distributed.

    Args:
        dtensor (torch.Tensor): distributed tensor (must carry ``dist_layout``).
        requires_grad (bool, optional): forwarded to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: a new parameter that shares ``dtensor``'s ``dist_layout``
        object and has ``detach``/``clone`` patched to propagate the layout.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."

    new_param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # the Parameter is a distinct object from dtensor, so attach the layout to it explicitly
    new_param.dist_layout = dtensor.dist_layout
    _hijack_detach_and_clone(new_param)
    return new_param
|
|
|
|
|
|
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
    """
    Make an existing parameter hold ``dtensor`` and become a distributed tensor.

    ``param.data`` is replaced by ``dtensor``. The parameter then shares
    ``dtensor``'s ``dist_layout`` object and gets its ``detach``/``clone`` patched.

    Args:
        dtensor (torch.Tensor): distributed tensor (must carry ``dist_layout``).
        param (torch.nn.Parameter): parameter updated in place.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."

    layout = dtensor.dist_layout
    param.data = dtensor
    param.dist_layout = layout
    _hijack_detach_and_clone(param)
|
|
|
|
|
|
def compute_global_numel(dtensor: torch.Tensor) -> int:
    """
    Compute the global number of elements in the distributed tensor.

    This is the product of the global shape stored in ``dist_layout``. An empty
    global shape (a 0-dim / scalar tensor) yields 1, like ``torch.Tensor.numel``.

    Args:
        dtensor (torch.Tensor): The distributed tensor.

    Returns:
        int: The global number of elements in the distributed tensor.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    # The initializer 1 makes an empty shape give 1 instead of making reduce()
    # raise "TypeError: reduce() of empty iterable with no initial value".
    numel = reduce(operator.mul, dtensor.dist_layout.global_shape, 1)
    return numel
|
|
|
|
|
|
def get_layout(dtensor: torch.Tensor) -> Layout:
    """
    Return the ``Layout`` attached to a distributed tensor.

    Args:
        dtensor (torch.Tensor): distributed tensor to query.

    Returns:
        Layout: ``dtensor.dist_layout`` itself (not a copy).
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    layout = dtensor.dist_layout
    return layout
|
|
|
|
|
|
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
    """
    Return the shape of the full, unsharded tensor described by ``dtensor``'s layout.

    Args:
        dtensor (torch.Tensor): distributed tensor to query.

    Returns:
        torch.Size: the ``global_shape`` stored in ``dtensor.dist_layout``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    layout = dtensor.dist_layout
    return layout.global_shape
|
|
|
|
|
|
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
    """
    Return the device mesh a distributed tensor is laid out on.

    Args:
        dtensor (torch.Tensor): distributed tensor to query.

    Returns:
        DeviceMesh: the ``device_mesh`` stored in ``dtensor.dist_layout``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    layout = dtensor.dist_layout
    return layout.device_mesh
|
|
|
|
|
|
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
    """
    Return the sharding specification of a distributed tensor.

    Args:
        dtensor (torch.Tensor): distributed tensor to query.

    Returns:
        ShardingSpec: the ``sharding_spec`` stored in ``dtensor.dist_layout``.
    """
    assert is_distributed_tensor(dtensor), "The input tensor is not a distributed tensor."
    layout = dtensor.dist_layout
    return layout.sharding_spec
|
|
|
|
|
|
# ======================================================
|
|
# Some sharding does not obey the SPMD style
|
|
# e.g. Fused QKV layer in GPT2
|
|
# we support customized sharding with the following APIs
|
|
# ======================================================
|
|
def is_customized_distributed_tensor(tensor: torch.Tensor):
    """
    Tell whether ``tensor`` is a customized distributed tensor.

    Such a tensor carries both a ``shard_fn`` and a ``gather_fn`` attribute.

    Args:
        tensor (torch.Tensor): tensor to inspect.

    Returns:
        bool: ``True`` only if both attributes are present.
    """
    required_attrs = ("shard_fn", "gather_fn")
    return all(hasattr(tensor, attr) for attr in required_attrs)
|
|
|
|
|
|
def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
|
|
"""
|
|
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
|
|
|
|
Args:
|
|
tensor (torch.Tensor): The tensor to be hijacked.
|
|
|
|
Returns:
|
|
torch.Tensor: The hijacked tensor.
|
|
"""
|
|
dtensor._old_detach = dtensor.detach
|
|
dtensor._old_clone = dtensor.clone
|
|
|
|
def new_detach(self):
|
|
t_ = self._old_detach()
|
|
t_.shard_fn = self.shard_fn
|
|
t_.gather_fn = self.gather_fn
|
|
return t_
|
|
|
|
def new_clone(self, *args, **kwargs):
|
|
t_ = self._old_clone(*args, **kwargs)
|
|
t_.shard_fn = self.shard_fn
|
|
t_.gather_fn = self.gather_fn
|
|
return t_
|
|
|
|
# bind the new methods to the tensor
|
|
dtensor.detach = new_detach.__get__(dtensor)
|
|
dtensor.clone = new_clone.__get__(dtensor)
|
|
return dtensor
|
|
|
|
|
|
def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Shard ``tensor`` with a user-supplied ``shard_fn`` and remember how to gather it back.

    ``shard_fn(tensor)`` produces the local shard. Both functions are then stored
    on that shard as ``shard_fn``/``gather_fn`` attributes, and its
    ``detach``/``clone`` are patched so that copies keep them.

    Example:

        ```python
        # define shard and gather functions
        def shard_fn(tensor):
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            return tensor.chunk(world_size, dim=0)[rank]

        def gather_fn(tensor):
            world_size = torch.distributed.get_world_size()
            shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(shard_list, tensor)
            return torch.cat(shard_list, dim=0)

        # create a distributed tensor
        tensor = torch.rand(4, 4)
        dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn)
        ```

    Args:
        tensor (torch.Tensor): the full tensor to distribute; must not carry ``dist_layout``.
        shard_fn (callable): maps the full tensor to the local shard.
        gather_fn (callable): maps a local shard back to the full tensor.

    Returns:
        torch.Tensor: the object returned by ``shard_fn``, annotated and patched.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    local_shard = shard_fn(tensor)

    # remember how this tensor was sharded and how to gather it
    local_shard.shard_fn, local_shard.gather_fn = shard_fn, gather_fn

    # keep shard_fn/gather_fn on the results of detach()/clone()
    _hijack_detach_and_clone_for_customized_distributed_tensor(local_shard)
    return local_shard
|
|
|
|
|
|
def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable):
    """
    Mark ``tensor`` as a customized distributed tensor WITHOUT sharding it.

    Unlike ``distribute_tensor_with_customization``, ``shard_fn`` is never called.
    ``tensor`` is used as-is; presumably it already holds this rank's shard, and
    the caller is responsible for that. ``shard_fn`` and ``gather_fn`` are only
    attached as attributes, and ``detach``/``clone`` are patched so that copies
    keep them.

    Example:

        ```python
        # define shard and gather functions
        def shard_fn(tensor):
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
            return tensor.chunk(world_size, dim=0)[rank]

        def gather_fn(tensor):
            world_size = torch.distributed.get_world_size()
            shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(shard_list, tensor)
            return torch.cat(shard_list, dim=0)

        # the local shard already exists (no sharding happens here)
        local_shard = torch.rand(4 // torch.distributed.get_world_size(), 4)
        dtensor = init_tensor_as_customization_distributed(local_shard, shard_fn, gather_fn)
        assert dtensor is local_shard
        ```

    Args:
        tensor (torch.Tensor): the local tensor to annotate in place; must not carry ``dist_layout``.
        shard_fn (callable): the function to shard the tensor (stored, not called).
        gather_fn (callable): the function to gather the tensor (stored, not called).

    Returns:
        torch.Tensor: the same ``tensor`` object, now carrying ``shard_fn``/``gather_fn``.
    """
    assert callable(shard_fn), "The shard_fn must be callable."
    assert callable(gather_fn), "The gather_fn must be callable."
    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."

    # set the shard_fn and gather_fn as attributes of the distributed tensor
    tensor.shard_fn = shard_fn
    tensor.gather_fn = gather_fn

    # make detach()/clone() results keep shard_fn and gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(tensor)

    return tensor
|
|
|
|
|
|
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
    """
    Rebuild the global tensor by applying the tensor's own ``gather_fn`` to it.

    Args:
        dtensor (torch.Tensor): a customized distributed tensor (carries ``shard_fn`` and ``gather_fn``).

    Returns:
        torch.Tensor: whatever ``dtensor.gather_fn(dtensor)`` returns.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."
    gather = dtensor.gather_fn
    return gather(dtensor)
|
|
|
|
|
|
def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
    """
    Wrap a customized distributed tensor in a ``torch.nn.Parameter`` that keeps its shard/gather functions.

    Args:
        dtensor (torch.Tensor): tensor carrying ``shard_fn`` and ``gather_fn``.
        requires_grad (bool, optional): forwarded to ``torch.nn.Parameter``. Defaults to True.

    Returns:
        torch.nn.Parameter: a new parameter with the same ``shard_fn``/``gather_fn``
        and patched ``detach``/``clone``.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

    new_param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
    # the Parameter is a distinct object from dtensor, so copy the functions onto it explicitly
    new_param.shard_fn, new_param.gather_fn = dtensor.shard_fn, dtensor.gather_fn
    _hijack_detach_and_clone_for_customized_distributed_tensor(new_param)
    return new_param
|
|
|
|
|
|
def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
    """
    Load a customized distributed tensor into an existing parameter.

    ``param.data`` is set to ``dtensor.data``. The parameter then receives
    ``dtensor``'s ``shard_fn``/``gather_fn``, and its ``detach``/``clone`` are
    patched to keep them.

    Args:
        dtensor (torch.Tensor): tensor carrying ``shard_fn`` and ``gather_fn``.
        param (torch.nn.Parameter): parameter updated in place.
    """
    assert is_customized_distributed_tensor(dtensor), "The input tensor is not a customized distributed tensor."

    param.data = dtensor.data
    for attr_name in ("shard_fn", "gather_fn"):
        setattr(param, attr_name, getattr(dtensor, attr_name))
    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
|