[dtensor] polish sharding spec docstring (#3838)

* [dtensor] polish sharding spec docstring

* [dtensor] polish sharding spec example docstring
This commit is contained in:
Hongxin Liu 2023-05-25 13:09:42 +08:00 committed by GitHub
parent 34966378e8
commit 7c9f2ed6dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -116,21 +116,21 @@ class DimSpec:
def dim_diff(self, other):
'''
The difference between two _DimSpec.
The difference between two DimSpec.
Argument:
other(_DimSpec): the dim spec to compare with.
other(DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
dim_spec = _DimSpec([0])
other_dim_spec = _DimSpec([0, 1])
```python
dim_spec = DimSpec([0])
other_dim_spec = DimSpec([0, 1])
print(dim_spec.difference(other_dim_spec))
Output:
5
# output: 5
```
'''
difference = self.difference_dict[(str(self), str(other))]
return difference
@ -142,9 +142,13 @@ class ShardingSpec:
[R, R, S0, S1], which means
Argument:
dim_size (int): The number of dimensions of the tensor to be sharded.
dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
and the value of the key describe which logical axis will be sharded in that dimension.
and the value of the key describe which logical axis will be sharded in that dimension. Defaults to None.
E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1.
sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
Generally, users should specify either dim_partition_dict or sharding_sequence.
If both are given, users must ensure that they are consistent with each other. Defaults to None.
'''
def __init__(self,
@ -208,6 +212,7 @@ class ShardingSpec:
pair of sharding sequence.
Example:
```python
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
@ -219,10 +224,8 @@ class ShardingSpec:
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
# output: 25
```
Argument:
other(ShardingSpec): The ShardingSpec to compared with.