mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
Revert "[sync] sync feature/shardformer with develop"
This commit is contained in:
@@ -16,66 +16,69 @@ def _all_gather(tensor, comm_spec):
|
||||
'''
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
|
||||
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
|
||||
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _split(tensor, comm_spec):
|
||||
'''
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
|
||||
start = length * dist.get_rank(process_group)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, _ in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_to_all(tensor, comm_spec):
|
||||
'''
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // world_size
|
||||
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [
|
||||
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
input_tensor_list = [
|
||||
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
|
||||
]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_reduce(tensor, comm_spec, async_op=False):
|
||||
'''
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
|
||||
|
||||
def _mix_gather(tensor, comm_spec):
|
||||
@@ -411,7 +414,7 @@ class CommSpec:
|
||||
self.forward_only = forward_only
|
||||
if isinstance(self.logical_process_axis, list):
|
||||
if not mix_gather:
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten()
|
||||
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
|
||||
self.logical_process_axis = 0
|
||||
else:
|
||||
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
|
||||
|
@@ -1,103 +0,0 @@
|
||||
# 🔢 Distributed Tensor
|
||||
|
||||
## 📚 Table of Contents
|
||||
|
||||
- [🔢 Distributed Tensor](#-distributed-tensor)
|
||||
- [📚 Table of Contents](#-table-of-contents)
|
||||
- [🔗 Introduction](#-introduction)
|
||||
- [📝 Design](#-design)
|
||||
- [🔨 Usage](#-usage)
|
||||
- [🎈 Progress Log](#-progress-log)
|
||||
|
||||
## 🔗 Introduction
|
||||
|
||||
Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training.
|
||||
It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor.
|
||||
|
||||
## 📝 Design
|
||||
|
||||
Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension.
|
||||
|
||||
Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below:
|
||||
|
||||
|
||||
```text
|
||||
[1, 2, 3, 4 ]
|
||||
A = [4, 5, 6, 7 ]
|
||||
[8, 9, 10, 11]
|
||||
[12, 13, 14, 15]
|
||||
```
|
||||
|
||||
`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology.
|
||||
|
||||
```text
|
||||
| --------------------—————————————————————-|
|
||||
| | |
|
||||
| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] |
|
||||
| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] |
|
||||
| | |
|
||||
| --------------------——————————————————-----
|
||||
| | |
|
||||
| [8, 9, 10, 11] | [8, 9, 10, 11] |
|
||||
| [12, 13, 14, 15] | [12, 13, 14, 15] |
|
||||
| | |
|
||||
| --------------------——————————————————-----
|
||||
```
|
||||
|
||||
`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology.
|
||||
|
||||
```text
|
||||
| --------------------—————————————————————-|
|
||||
| | |
|
||||
| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] |
|
||||
| | |
|
||||
| --------------------——————————————————-----
|
||||
| | |
|
||||
| [8, 9, 10, 11] | [12, 13, 14, 15] |
|
||||
| | |
|
||||
| --------------------——————————————————-----
|
||||
```
|
||||
|
||||
## 🔨 Usage
|
||||
|
||||
A sample API usage is given below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.d_tensor import DTensor, ShardingSpec
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
# define your device mesh
|
||||
# assume you have 4 GPUs
|
||||
physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
# define a tensor
|
||||
a = torch.rand(16, 32).cuda()
|
||||
|
||||
# create sharding spec for the tensor
|
||||
# assume the sharding spec is [S0, R]
|
||||
dim_partition_dict = {0: [0]}
|
||||
sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)
|
||||
|
||||
# create a distributed tensor
|
||||
d_tensor = DTensor(a, device_mesh, sharding_spec)
|
||||
print(d_tensor)
|
||||
|
||||
global_tensor = d_tensor.to_global()
|
||||
print(global_tensor)
|
||||
```
|
||||
|
||||
|
||||
## 🎈 Progress Log
|
||||
|
||||
- [x] Support layout conversion
|
||||
- [x] Support sharding on 2D device mesh
|
||||
- [ ] Support sharding on 3D device mesh
|
||||
- [ ] Support sharding 4D device mesh
|
||||
- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.)
|
@@ -1,4 +0,0 @@
|
||||
from .d_tensor import DTensor
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['DTensor', 'ShardingSpec']
|
||||
|
@@ -24,12 +24,12 @@ class CommSpec:
|
||||
'''
|
||||
Communication spec is used to record the communication action. It converts the communication spec
|
||||
to real action which will be used in runtime. It contains comm_pattern to determine the
|
||||
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
|
||||
communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
|
||||
to determine the buffer shape, and logical_process_axis
|
||||
|
||||
Argument:
|
||||
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
|
||||
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
|
||||
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
|
||||
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
|
||||
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
|
||||
@@ -37,7 +37,7 @@ class CommSpec:
|
||||
|
||||
def __init__(self,
|
||||
comm_pattern: CollectiveCommPattern,
|
||||
process_group_dict: Dict,
|
||||
process_groups_dict: Dict,
|
||||
gather_dim: int = None,
|
||||
shard_dim: int = None,
|
||||
logical_process_axis: int = None):
|
||||
@@ -45,7 +45,7 @@ class CommSpec:
|
||||
self.gather_dim = gather_dim
|
||||
self.shard_dim = shard_dim
|
||||
self.logical_process_axis = logical_process_axis
|
||||
self.process_group_dict = process_group_dict
|
||||
self.process_groups_dict = process_groups_dict
|
||||
|
||||
def __repr__(self):
|
||||
res_list = ["CommSpec:("]
|
||||
@@ -92,56 +92,68 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
|
||||
start = length * dist.get_rank(process_group)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, _ in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // world_size
|
||||
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [
|
||||
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
input_tensor_list = [
|
||||
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
|
||||
]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
|
||||
'''
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
@@ -257,7 +269,7 @@ class _AllToAll(torch.autograd.Function):
|
||||
def forward(ctx, input_, comm_spec):
|
||||
output = _all_to_all(input_, comm_spec)
|
||||
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
|
||||
process_group_dict=comm_spec.process_group_dict,
|
||||
process_groups_dict=comm_spec.process_groups_dict,
|
||||
gather_dim=comm_spec.shard_dim,
|
||||
shard_dim=comm_spec.gather_dim,
|
||||
logical_process_axis=comm_spec.logical_process_axis)
|
||||
|
@@ -3,119 +3,55 @@ from typing import Optional
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
from .layout import Layout
|
||||
from .layout_converter import LayoutConverter, to_global
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['DTensor', 'distribute_tensor', 'distribute_module', 'construct_default_sharding_spec']
|
||||
|
||||
layout_converter = LayoutConverter()
|
||||
|
||||
|
||||
class DTensor(torch.Tensor):
|
||||
"""
|
||||
DTensor stands for distributed tensor. It is a subclass of `torch.Tensor` and contains meta information
|
||||
about the tensor distribution. The meta information includes the device mesh, the sharding specification,
|
||||
and the entire shape of the tensor.
|
||||
|
||||
During runtime, we will not directly use the DTensor objects for computation. Instead, we will only use the
|
||||
`DTensor.local_tensor` for computation. The `DTensor.local_tensor` is the local tensor in the current rank.
|
||||
In this way, all tensors involved in computation will only be native PyTorch tensors.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from colossalai.device import DeviceMesh
|
||||
|
||||
# define your device mesh
|
||||
# assume you have 4 GPUs
|
||||
physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
|
||||
# define a tensor
|
||||
x = torch.rand(16, 32)
|
||||
|
||||
# create sharding spec for the tensor
|
||||
# assume the sharding spec is [S, R]
|
||||
dim_partition_dict = {
|
||||
0: 1
|
||||
}
|
||||
sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)
|
||||
|
||||
# create a distributed tensor
|
||||
d_tensor = DTensor(x, device_mesh, sharding_spec)
|
||||
```
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): the unsharded tensor.
|
||||
device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
|
||||
sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec):
|
||||
# ensure this tensor is not a DTensor
|
||||
assert not isinstance(tensor, DTensor), 'The input tensor should not be a DTensor.'
|
||||
|
||||
# store meta info
|
||||
self.local_tensor = tensor
|
||||
self.data_type = tensor.dtype
|
||||
self.global_shape = tensor.shape
|
||||
|
||||
# create distributed layout
|
||||
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
|
||||
def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
|
||||
self.local_tensor = local_tensor
|
||||
self.data_type = local_tensor.dtype
|
||||
self.entire_shape = local_tensor.shape
|
||||
self.dist_layout = dist_layout
|
||||
|
||||
# shard the tensor
|
||||
self._apply_layout()
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, tensor, *args, **kwargs):
|
||||
return torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad)
|
||||
def __new__(cls, local_tensor, layout):
|
||||
return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
|
||||
|
||||
def __repr__(self):
|
||||
return f"DTensor(\n{self.to_global()}\n{self.dist_layout}"
|
||||
return f"DTensor({self.to_global()}, {self.dist_layout})"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def layout_convert(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
|
||||
def layout_convert(self, target_layout):
|
||||
'''
|
||||
Convert the layout of the tensor from source_spec to target_spec.
|
||||
This will update the `local_tensor` and `dist_layout` in place.
|
||||
|
||||
Args:
|
||||
target_layout (Layout): the target layout specification.
|
||||
'''
|
||||
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
|
||||
self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
|
||||
source_layout=self.dist_layout,
|
||||
target_layout=target_layout)
|
||||
self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
|
||||
self.dist_layout = target_layout
|
||||
|
||||
def _apply_layout(self):
|
||||
'''
|
||||
Apply the layout to the local tensor during initializing process.
|
||||
'''
|
||||
# layout converter requires a source and target laytout
|
||||
# we construct the source layer for an unsharded tensor
|
||||
# and use self.dist_layer as the targer layout for the sharded tensor
|
||||
source_spec = construct_default_sharding_spec(self.local_tensor)
|
||||
source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
|
||||
device_type=self.dist_layout.device_type,
|
||||
sharding_spec=source_spec,
|
||||
global_shape=self.global_shape)
|
||||
self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
|
||||
source_layout=source_layout,
|
||||
target_layout=self.dist_layout)
|
||||
entire_shape=self.entire_shape)
|
||||
self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# convert all DTensors to native pytorch tensors
|
||||
# so that operations will be conducted on native tensors
|
||||
def filter_arg(arg):
|
||||
if isinstance(arg, DTensor):
|
||||
return arg.local_tensor
|
||||
@@ -124,9 +60,9 @@ class DTensor(torch.Tensor):
|
||||
|
||||
args = tree_map(filter_arg, args)
|
||||
kwargs = tree_map(filter_arg, kwargs)
|
||||
|
||||
# NOTE: if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
|
||||
# if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
|
||||
# and op type.
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
@@ -149,6 +85,7 @@ class DTensor(torch.Tensor):
|
||||
'''
|
||||
self.local_tensor = self.local_tensor.to(*args, **kwargs)
|
||||
self.data_type = self.local_tensor.dtype
|
||||
self.dist_layout.device_type = self.local_tensor.device
|
||||
# TODO: update the device mesh process groups or we should just cache
|
||||
# both the cpu process groups and the cuda process groups?
|
||||
return self
|
||||
@@ -161,7 +98,7 @@ class DTensor(torch.Tensor):
|
||||
|
||||
def to_global(self):
|
||||
'''
|
||||
Recover the global tensor from the distributed tensor by returning a new `torch.Tensor` object.
|
||||
Recover the global tensor from the distributed tensor.
|
||||
|
||||
Note: This function will all_gather the local tensor to the global tensor and it
|
||||
will not change the layout of the DTensor. This function is mainly used for debugging or
|
||||
@@ -170,29 +107,24 @@ class DTensor(torch.Tensor):
|
||||
return to_global(self.local_tensor, self.dist_layout)
|
||||
|
||||
|
||||
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> DTensor:
|
||||
def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
|
||||
'''
|
||||
Distribute the local tensor to the distributed tensor according to the dist_layout specified.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): tensor to be distributed.
|
||||
device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
|
||||
sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
|
||||
local_tensor: tensor to be distributed.
|
||||
dist_layout: the layout specification of the distributed tensor.
|
||||
|
||||
Returns:
|
||||
A 'DTensor' object.
|
||||
'''
|
||||
return DTensor(tensor, device_mesh, sharding_spec)
|
||||
return DTensor(local_tensor, dist_layout)
|
||||
|
||||
|
||||
def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
|
||||
'''
|
||||
This function converts all the parameters in the module to DTensor(DParam).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): the module to be distributed.
|
||||
partition_fn (callable): the partition function which will be used to partition the parameters.
|
||||
|
||||
Note: This function is subject to future change as the DParam has not been implemented yet.
|
||||
'''
|
||||
for name, param in module.named_parameters():
|
||||
@@ -206,11 +138,5 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
|
||||
def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
|
||||
'''
|
||||
Construct the default sharding specification for the tensor.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): the tensor to be sharded.
|
||||
|
||||
Returns:
|
||||
A `ShardingSpec` object without any sharding specified.
|
||||
'''
|
||||
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
|
||||
|
@@ -11,32 +11,28 @@ from .sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
class Layout:
|
||||
"""
|
||||
Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices.
|
||||
"""Layout of a tensor.
|
||||
|
||||
Args:
|
||||
device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed.
|
||||
sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded.
|
||||
global_shape (`torch.Size`): the entire shape of the global tensor.
|
||||
Attributes:
|
||||
device_mesh: the device mesh to store the tensor distributed.
|
||||
device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
|
||||
sharding_spec: the sharding specification to describe how the tensor is sharded.
|
||||
entire_shape: the entire shape of the global tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
|
||||
def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
|
||||
entire_shape: torch.Size):
|
||||
self.device_mesh = device_mesh
|
||||
self.device_type = device_type
|
||||
self.sharding_spec = sharding_spec
|
||||
self.global_shape = global_shape
|
||||
self.entire_shape = entire_shape
|
||||
self._sanity_check()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(f'{self.sharding_spec}')
|
||||
|
||||
def get_sharded_shape_per_device(self) -> torch.Size:
|
||||
"""
|
||||
Compute the shape of the sharded tensor on each device.
|
||||
|
||||
Returns:
|
||||
`torch.Size`: the shape of the sharded tensor on each device.
|
||||
"""
|
||||
sharded_shape = list(self.global_shape)
|
||||
def get_sharded_shape_per_device(self):
|
||||
sharded_shape = list(self.entire_shape)
|
||||
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
|
||||
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
|
||||
shard_partitions = reduce(operator.mul, mesh_list, 1)
|
||||
@@ -60,7 +56,7 @@ class Layout:
|
||||
|
||||
# make sure that the sharding for a dimension is divisible by the number of devices
|
||||
for dim, shard_list in sharding_spec.dim_partition_dict.items():
|
||||
tensor_dim_size = self.global_shape[dim]
|
||||
tensor_dim_size = self.entire_shape[dim]
|
||||
num_devices = 1
|
||||
|
||||
for element in shard_list:
|
||||
|
@@ -3,8 +3,10 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.d_tensor.comm_spec import *
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
@@ -26,21 +28,13 @@ class LayoutConverterOptions:
|
||||
pass
|
||||
|
||||
|
||||
def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor:
|
||||
"""
|
||||
Convert a distributed tensor to the global tensor with the given layout.
|
||||
This function returns a native `torch.Tensor` object.
|
||||
|
||||
|
||||
Args:
|
||||
distributed_tensor (`DTensor`): the distributed tensor to be converted.
|
||||
layout (`Layout`): the target layout specification.
|
||||
"""
|
||||
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
|
||||
layout_converter = LayoutConverter()
|
||||
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
|
||||
global_layout = Layout(device_mesh=layout.device_mesh,
|
||||
device_type=layout.device_type,
|
||||
sharding_spec=global_sharding_spec,
|
||||
global_shape=layout.global_shape)
|
||||
entire_shape=layout.entire_shape)
|
||||
with torch.no_grad():
|
||||
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
|
||||
return global_tensor
|
||||
@@ -55,9 +49,6 @@ def set_layout_converting_options(options: LayoutConverterOptions):
|
||||
|
||||
|
||||
class LayoutConverter(metaclass=SingletonMeta):
|
||||
"""
|
||||
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._options = None
|
||||
@@ -100,14 +91,15 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
global_shape = (4, 4, 4)
|
||||
entire_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
|
||||
rst_dict = layout_converter.all_gather_transform_layouts(layout)
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -120,12 +112,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
for target_pair in source_spec.dim_partition_dict.items():
|
||||
shard_list = all_gather_simulator(target_pair)
|
||||
index = target_pair[0]
|
||||
@@ -143,7 +130,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
logical_process_axis = target_pair[1][-1]
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
process_group_dict=process_group_dict,
|
||||
process_groups_dict=process_groups_dict,
|
||||
gather_dim=gather_dim,
|
||||
# shard_dim will be used during backward
|
||||
shard_dim=gather_dim,
|
||||
@@ -154,7 +141,8 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
global_shape=source_layout.global_shape)
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
@@ -179,14 +167,15 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
global_shape = (4, 4, 4)
|
||||
entire_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
rst_dict = layout_converter.all_to_all_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -199,12 +188,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
source_spec = source_layout.sharding_spec
|
||||
tensor_dims = source_spec.dims
|
||||
for f_index in range(tensor_dims - 1):
|
||||
@@ -245,7 +229,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
shard_dim = f_index
|
||||
logical_process_axis = b_target_pair[1][-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_group_dict=process_group_dict,
|
||||
process_groups_dict,
|
||||
gather_dim=gather_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
@@ -268,7 +252,8 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
global_shape=source_layout.global_shape)
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
pass
|
||||
@@ -293,15 +278,16 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
global_shape = (4, 4, 4)
|
||||
entire_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_dict = {0: [0]}
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
rst_dict = layout_converter.shard_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -315,11 +301,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# legal sharding dims means the mesh_id is still available to use.
|
||||
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
|
||||
@@ -347,7 +329,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
shard_dim = index
|
||||
logical_process_axis = shard_list[-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_group_dict=process_group_dict,
|
||||
process_groups_dict,
|
||||
gather_dim=shard_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
@@ -358,7 +340,8 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
global_shape=source_layout.global_shape)
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
pass
|
||||
@@ -416,7 +399,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
global_shape = (4, 4, 4)
|
||||
entire_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
@@ -424,14 +407,16 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [R,S01,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
|
||||
# [S01,R,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
|
||||
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
|
||||
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
|
||||
@@ -520,19 +505,21 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
global_shape = (4, 4, 4)
|
||||
entire_shape = (4, 4, 4)
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
|
||||
# [R,S0,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
global_shape=global_shape)
|
||||
entire_shape=entire_shape)
|
||||
|
||||
if rank in (0, 1):
|
||||
sharded_tensor_0 = torch.zeros(2, 1)
|
||||
@@ -567,4 +554,3 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
for comm_spec in comm_action_sequence:
|
||||
tensor = comm_spec.covert_spec_to_action(tensor)
|
||||
return tensor
|
||||
return tensor
|
||||
|
@@ -116,21 +116,21 @@ class DimSpec:
|
||||
|
||||
def dim_diff(self, other):
|
||||
'''
|
||||
The difference between two DimSpec.
|
||||
The difference between two _DimSpec.
|
||||
|
||||
Argument:
|
||||
other(DimSpec): the dim spec to compare with.
|
||||
other(_DimSpec): the dim spec to compare with.
|
||||
|
||||
Return:
|
||||
difference(int): the difference between two _DimSpec.
|
||||
|
||||
Example:
|
||||
```python
|
||||
dim_spec = DimSpec([0])
|
||||
other_dim_spec = DimSpec([0, 1])
|
||||
dim_spec = _DimSpec([0])
|
||||
other_dim_spec = _DimSpec([0, 1])
|
||||
print(dim_spec.difference(other_dim_spec))
|
||||
# output: 5
|
||||
```
|
||||
|
||||
Output:
|
||||
5
|
||||
'''
|
||||
difference = self.difference_dict[(str(self), str(other))]
|
||||
return difference
|
||||
@@ -142,13 +142,9 @@ class ShardingSpec:
|
||||
[R, R, S0, S1], which means
|
||||
|
||||
Argument:
|
||||
dim_size (int): The number of dimensions of the tensor to be sharded.
|
||||
dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
|
||||
and the value of the key describe which logical axis will be sharded in that dimension. Defaults to None.
|
||||
E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1.
|
||||
sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
|
||||
Generally, users should specify either dim_partition_dict or sharding_sequence.
|
||||
If both are given, users must ensure that they are consistent with each other. Defaults to None.
|
||||
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
|
||||
and the value of the key describe which logical axis will be sharded in that dimension.
|
||||
sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
|
||||
'''
|
||||
|
||||
def __init__(self,
|
||||
@@ -212,7 +208,6 @@ class ShardingSpec:
|
||||
pair of sharding sequence.
|
||||
|
||||
Example:
|
||||
```python
|
||||
dim_partition_dict = {0: [0, 1]}
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
@@ -224,8 +219,10 @@ class ShardingSpec:
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
|
||||
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
|
||||
# output: 25
|
||||
```
|
||||
|
||||
Output:
|
||||
25
|
||||
|
||||
Argument:
|
||||
other(ShardingSpec): The ShardingSpec to compared with.
|
||||
|
||||
|
Reference in New Issue
Block a user